diff --git a/backend/packages/harness/deerflow/persistence/feedback/sql.py b/backend/packages/harness/deerflow/persistence/feedback/sql.py index 508c86fa7..1db74ce84 100644 --- a/backend/packages/harness/deerflow/persistence/feedback/sql.py +++ b/backend/packages/harness/deerflow/persistence/feedback/sql.py @@ -162,6 +162,44 @@ class FeedbackRepository: await session.refresh(row) return self._row_to_dict(row) + async def delete_by_run( + self, + *, + thread_id: str, + run_id: str, + user_id: str | None | _AutoSentinel = AUTO, + ) -> bool: + """Delete the current user's feedback for a run. Returns True if a record was deleted.""" + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run") + async with self._sf() as session: + stmt = select(FeedbackRow).where( + FeedbackRow.thread_id == thread_id, + FeedbackRow.run_id == run_id, + FeedbackRow.user_id == resolved_user_id, + ) + result = await session.execute(stmt) + row = result.scalar_one_or_none() + if row is None: + return False + await session.delete(row) + await session.commit() + return True + + async def list_by_thread_grouped( + self, + thread_id: str, + *, + user_id: str | None | _AutoSentinel = AUTO, + ) -> dict[str, dict]: + """Return feedback grouped by run_id for a thread: {run_id: feedback_dict}.""" + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped") + stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id) + if resolved_user_id is not None: + stmt = stmt.where(FeedbackRow.user_id == resolved_user_id) + async with self._sf() as session: + result = await session.execute(stmt) + return {row.run_id: self._row_to_dict(row) for row in result.scalars()} + async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict: """Aggregate feedback stats for a run using database-side counting.""" stmt = select( diff --git a/backend/tests/test_feedback.py b/backend/tests/test_feedback.py index 01735ffbb..a592bdd22 100644 --- a/backend/tests/test_feedback.py +++ b/backend/tests/test_feedback.py @@ -190,6 +190,44 @@ class TestFeedbackRepository: await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1") await _cleanup() + @pytest.mark.anyio + async def test_delete_by_run(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") + assert deleted is True + results = await repo.list_by_run("t1", "r1", user_id="u1") + assert len(results) == 0 + await _cleanup() + + @pytest.mark.anyio + async def test_delete_by_run_nonexistent(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") + assert deleted is False + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_thread_grouped(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1") + await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1") + grouped = await repo.list_by_thread_grouped("t1", user_id="u1") + assert "r1" in grouped + assert "r2" in grouped + assert "r3" not in grouped + assert grouped["r1"]["rating"] == 1 + assert grouped["r2"]["rating"] == -1 + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_thread_grouped_empty(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + grouped = await repo.list_by_thread_grouped("t1", user_id="u1") + assert grouped == {} + await _cleanup() + # -- Follow-up association --