mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 19:28:23 +00:00
feat(feedback): add UNIQUE(thread_id, run_id, user_id) constraint
Add UNIQUE constraint to FeedbackRow to enforce one feedback per user per run, enabling upsert behavior in Task 2. Update tests to use distinct user_ids for multiple feedback records per run, and pass user_id=None to list_by_run for admin-style queries that bypass user isolation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
b2ec1f99b9
commit
60a5ad7279
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DateTime, String, Text
|
||||
from sqlalchemy import DateTime, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@ -13,6 +13,10 @@ from deerflow.persistence.base import Base
|
||||
class FeedbackRow(Base):
|
||||
__tablename__ = "feedback"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
|
||||
)
|
||||
|
||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
|
||||
@ -97,10 +97,10 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=1)
|
||||
results = await repo.list_by_run("t1", "r1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2")
|
||||
await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1")
|
||||
results = await repo.list_by_run("t1", "r1", user_id=None)
|
||||
assert len(results) == 2
|
||||
assert all(r["run_id"] == "r1" for r in results)
|
||||
await _cleanup()
|
||||
@ -135,9 +135,9 @@ class TestFeedbackRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_aggregate_by_run(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1)
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2")
|
||||
await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3")
|
||||
stats = await repo.aggregate_by_run("t1", "r1")
|
||||
assert stats["total"] == 3
|
||||
assert stats["positive"] == 2
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user