feat(persistence): add user feedback + follow-up run association

Phase 2-C: feedback and follow-up tracking.

- FeedbackRow ORM model (rating +1/-1, optional message_id, comment)
- FeedbackRepository with CRUD, list_by_run/thread, aggregate stats
- Feedback API endpoints: create, list, stats, delete
- follow_up_to_run_id in RunCreateRequest (explicit or auto-detected
  from latest successful run on the thread)
- Worker writes follow_up_to_run_id into human_message event metadata
- Gateway deps: feedback_repo factory + getter
- 17 new tests (14 FeedbackRepository + 3 follow-up association)
- 109 total tests pass, zero regressions

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-02 19:10:11 +08:00
parent e3179cd54d
commit 5cb0471af5
11 changed files with 508 additions and 3 deletions

View File

@ -11,6 +11,7 @@ from app.gateway.routers import (
artifacts,
assistants_compat,
channels,
feedback,
mcp,
memory,
models,
@ -199,6 +200,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
app.include_router(feedback.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)

View File

@ -46,6 +46,9 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
# Initialize run event store based on config
app.state.run_event_store = _make_run_event_store(config)
# Initialize feedback repository (None when no DB engine)
app.state.feedback_repo = _make_feedback_repo()
# RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store)
@ -74,6 +77,18 @@ def _make_run_store() -> RunStore:
return MemoryRunStore()
def _make_feedback_repo():
    """Build a FeedbackRepository when a DB engine exists, else return None."""
    from deerflow.persistence.engine import get_session_factory

    session_factory = get_session_factory()
    if session_factory is None:
        # No database configured — feedback endpoints will answer 503.
        return None
    from deerflow.persistence.repositories.feedback_repo import FeedbackRepository

    return FeedbackRepository(session_factory)
def _make_run_event_store(config) -> RunEventStore:
from deerflow.runtime.events.store import make_run_event_store
@ -123,6 +138,14 @@ def get_run_event_store(request: Request) -> RunEventStore:
return store
def get_feedback_repo(request: Request):
    """Return the app-state FeedbackRepository, or 503 if persistence is off."""
    feedback_repo = getattr(request.app.state, "feedback_repo", None)
    if feedback_repo is not None:
        return feedback_repo
    # feedback_repo is None when no DB engine was available at startup.
    raise HTTPException(status_code=503, detail="Feedback not available")
def get_run_store(request: Request) -> RunStore:
"""Return the RunStore, or 503 if not available."""
store = getattr(request.app.state, "run_store", None)

View File

@ -0,0 +1,121 @@
"""Feedback endpoints — create, list, stats, delete.
Allows users to submit thumbs-up/down feedback on runs,
optionally scoped to a specific message.
"""
from __future__ import annotations
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["feedback"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class FeedbackCreateRequest(BaseModel):
    """Request body for submitting feedback on a run.

    The endpoint additionally enforces that ``rating`` is exactly +1 or -1
    (a 400 is raised otherwise); the model itself only requires an int.
    """

    rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
    comment: str | None = Field(default=None, description="Optional text feedback")
    message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
class FeedbackResponse(BaseModel):
    """A persisted feedback record as returned by the repository."""

    feedback_id: str
    run_id: str
    thread_id: str
    owner_id: str | None = None
    # Optional event identifier when the feedback targets a single message.
    message_id: str | None = None
    rating: int
    comment: str | None = None
    # ISO-8601 timestamp string (repository serializes datetime via isoformat).
    created_at: str = ""
class FeedbackStatsResponse(BaseModel):
    """Aggregated thumbs-up/down counts for a single run."""

    run_id: str
    total: int = 0
    positive: int = 0
    negative: int = 0
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
async def create_feedback(
    thread_id: str,
    run_id: str,
    body: FeedbackCreateRequest,
    request: Request,
) -> dict[str, Any]:
    """Submit thumbs-up/down feedback for a run, optionally scoped to a message."""
    if body.rating not in (1, -1):
        raise HTTPException(status_code=400, detail="rating must be +1 or -1")

    user_id = await get_current_user(request)

    # The run must exist and live on the addressed thread before accepting feedback.
    run = await get_run_store(request).get(run_id)
    if run is None:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    if run.get("thread_id") != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")

    repo = get_feedback_repo(request)
    record = await repo.create(
        run_id=run_id,
        thread_id=thread_id,
        rating=body.rating,
        owner_id=user_id,
        message_id=body.message_id,
        comment=body.comment,
    )
    return record
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
async def list_feedback(
    thread_id: str,
    run_id: str,
    request: Request,
) -> list[dict[str, Any]]:
    """Return every feedback record for a run, oldest first."""
    repo = get_feedback_repo(request)
    records = await repo.list_by_run(thread_id, run_id)
    return records
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
async def feedback_stats(
    thread_id: str,
    run_id: str,
    request: Request,
) -> dict[str, Any]:
    """Return aggregated positive/negative feedback counts for a run."""
    repo = get_feedback_repo(request)
    stats = await repo.aggregate_by_run(thread_id, run_id)
    return stats
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
async def delete_feedback(
    thread_id: str,
    run_id: str,
    feedback_id: str,
    request: Request,
) -> dict[str, bool]:
    """Delete a feedback record.

    The record must belong to the thread/run addressed in the URL path;
    a 404 is returned otherwise. Previously any feedback_id could be
    deleted through an unrelated thread/run path.
    """
    feedback_repo = get_feedback_repo(request)
    # Verify existence AND path scoping before deleting.
    record = await feedback_repo.get(feedback_id)
    if (
        record is None
        or record.get("thread_id") != thread_id
        or record.get("run_id") != run_id
    ):
        raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
    deleted = await feedback_repo.delete(feedback_id)
    if not deleted:
        # Lost a race with a concurrent delete — treat the same as missing.
        raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
    return {"success": True}

View File

@ -52,6 +52,7 @@ class RunCreateRequest(BaseModel):
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
follow_up_to_run_id: str | None = Field(default=None, description="Run ID this message follows up on. Auto-detected from latest successful run if not provided.")
class RunResponse(BaseModel):

View File

@ -17,7 +17,7 @@ from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_store, get_stream_bridge
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
@ -274,6 +274,17 @@ async def start_run(
if store is not None:
await _upsert_thread_in_store(store, thread_id, body.metadata)
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
if follow_up_to_run_id is None:
run_store = get_run_store(request)
try:
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
if recent_runs and recent_runs[0].get("status") == "success":
follow_up_to_run_id = recent_runs[0]["run_id"]
except Exception:
pass # Don't block run creation
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
@ -295,6 +306,7 @@ async def start_run(
interrupt_after=body.interrupt_after,
event_store=event_store,
run_events_config=run_events_config,
follow_up_to_run_id=follow_up_to_run_id,
)
)
record.task = task

View File

@ -1,5 +1,6 @@
from deerflow.persistence.models.feedback import FeedbackRow
from deerflow.persistence.models.run import RunRow
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.models.thread_meta import ThreadMetaRow
__all__ = ["RunEventRow", "RunRow", "ThreadMetaRow"]
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow"]

View File

@ -0,0 +1,30 @@
"""ORM model for user feedback on runs."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import String, Text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class FeedbackRow(Base):
    """ORM row for user feedback (thumbs-up/down) on a run.

    ``rating`` holds +1 or -1; the range is enforced in
    FeedbackRepository.create rather than by a DB constraint.
    """

    __tablename__ = "feedback"
    feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    # NULL owner_id means the feedback was submitted without a resolved user.
    owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
    message_id: Mapped[str | None] = mapped_column(String(64))
    # message_id is an optional RunEventStore event identifier —
    # allows feedback to target a specific message or the entire run
    rating: Mapped[int] = mapped_column(nullable=False)
    # +1 (thumbs-up) or -1 (thumbs-down)
    comment: Mapped[str | None] = mapped_column(Text)
    # Optional text feedback from the user
    # Timezone-aware (UTC) creation time; the lambda is evaluated per insert.
    created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))

View File

@ -1,4 +1,5 @@
from deerflow.persistence.repositories.feedback_repo import FeedbackRepository
from deerflow.persistence.repositories.run_repo import RunRepository
from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository
__all__ = ["RunRepository", "ThreadMetaRepository"]
__all__ = ["FeedbackRepository", "RunRepository", "ThreadMetaRepository"]

View File

@ -0,0 +1,97 @@
"""SQLAlchemy-backed feedback storage.
Each method acquires its own short-lived session.
"""
from __future__ import annotations
import logging
import uuid
from datetime import UTC, datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.feedback import FeedbackRow
logger = logging.getLogger(__name__)
class FeedbackRepository:
    """SQLAlchemy-backed CRUD + aggregation for run feedback.

    Every method opens its own short-lived AsyncSession from the injected
    session factory, so instances hold no connection state of their own.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    @staticmethod
    def _row_to_dict(row: FeedbackRow) -> dict:
        """Convert an ORM row to a plain dict, serializing created_at to ISO-8601."""
        d = row.to_dict()
        val = d.get("created_at")
        if isinstance(val, datetime):
            d["created_at"] = val.isoformat()
        return d

    async def create(
        self,
        *,
        run_id: str,
        thread_id: str,
        rating: int,
        owner_id: str | None = None,
        message_id: str | None = None,
        comment: str | None = None,
    ) -> dict:
        """Create a feedback record. rating must be +1 or -1.

        Raises:
            ValueError: if *rating* is anything other than +1 or -1.
        """
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")
        row = FeedbackRow(
            feedback_id=str(uuid.uuid4()),
            run_id=run_id,
            thread_id=thread_id,
            owner_id=owner_id,
            message_id=message_id,
            rating=rating,
            comment=comment,
            created_at=datetime.now(UTC),
        )
        async with self._sf() as session:
            session.add(row)
            await session.commit()
            # Refresh to pick up any DB-side defaults before serializing.
            await session.refresh(row)
            return self._row_to_dict(row)

    async def get(self, feedback_id: str) -> dict | None:
        """Return one feedback record by id, or None when absent."""
        async with self._sf() as session:
            row = await session.get(FeedbackRow, feedback_id)
            return self._row_to_dict(row) if row else None

    async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]:
        """List feedback for one run, oldest first, capped at *limit*."""
        stmt = (
            select(FeedbackRow)
            .where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
            .order_by(FeedbackRow.created_at.asc())
            .limit(limit)
        )
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]:
        """List feedback across a whole thread, oldest first, capped at *limit*."""
        stmt = (
            select(FeedbackRow)
            .where(FeedbackRow.thread_id == thread_id)
            .order_by(FeedbackRow.created_at.asc())
            .limit(limit)
        )
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def delete(self, feedback_id: str) -> bool:
        """Delete a feedback record; return True if it existed."""
        async with self._sf() as session:
            row = await session.get(FeedbackRow, feedback_id)
            if row is None:
                return False
            await session.delete(row)
            await session.commit()
            return True

    async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
        """Aggregate feedback stats for a run.

        Counting is done in SQL (GROUP BY rating) so the numbers stay correct
        regardless of row count — the previous implementation materialized at
        most 10,000 rows in Python and silently undercounted beyond that.
        """
        from sqlalchemy import func  # local: file-level import only brings `select`

        stmt = (
            select(FeedbackRow.rating, func.count())
            .where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
            .group_by(FeedbackRow.rating)
        )
        async with self._sf() as session:
            result = await session.execute(stmt)
            counts = {rating: count for rating, count in result.all()}
        return {
            "run_id": run_id,
            "total": sum(counts.values()),
            "positive": counts.get(1, 0),
            "negative": counts.get(-1, 0),
        }

View File

@ -47,6 +47,7 @@ async def run_agent(
interrupt_after: list[str] | Literal["*"] | None = None,
event_store: Any | None = None,
run_events_config: Any | None = None,
follow_up_to_run_id: str | None = None,
) -> None:
"""Execute an agent in the background, publishing events to *bridge*."""
@ -69,12 +70,16 @@ async def run_agent(
# Write human_message event
user_input = _extract_user_input(graph_input)
if user_input:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
await event_store.put(
thread_id=thread_id,
run_id=run_id,
event_type="human_message",
category="message",
content=user_input,
metadata=msg_metadata or None,
)
journal.set_first_human_message(user_input)

View File

@ -0,0 +1,210 @@
"""Tests for FeedbackRepository and follow-up association.
Uses temp SQLite DB for ORM tests.
"""
import pytest
from deerflow.persistence.repositories.feedback_repo import FeedbackRepository
async def _make_feedback_repo(tmp_path):
    """Initialize a throwaway SQLite engine under *tmp_path* and return a repo."""
    from deerflow.persistence.engine import get_session_factory, init_engine

    db_url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path))
    return FeedbackRepository(get_session_factory())
async def _cleanup():
    """Close the global DB engine so the next test starts from a clean slate."""
    from deerflow.persistence import engine

    await engine.close_engine()
# -- FeedbackRepository --
class TestFeedbackRepository:
    """CRUD and aggregation tests for FeedbackRepository on a temp SQLite DB."""

    @pytest.fixture(autouse=True)
    async def _teardown_engine(self):
        # Always dispose of the global engine after each test. Previously each
        # test ended with `await _cleanup()`, which was skipped whenever an
        # assertion failed — leaking the engine into subsequent tests.
        yield
        await _cleanup()

    @pytest.mark.anyio
    async def test_create_positive(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1)
        assert record["feedback_id"]
        assert record["rating"] == 1
        assert record["run_id"] == "r1"
        assert record["thread_id"] == "t1"
        assert "created_at" in record

    @pytest.mark.anyio
    async def test_create_negative_with_comment(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(
            run_id="r1", thread_id="t1", rating=-1, comment="Response was inaccurate",
        )
        assert record["rating"] == -1
        assert record["comment"] == "Response was inaccurate"

    @pytest.mark.anyio
    async def test_create_with_message_id(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, message_id="msg-42")
        assert record["message_id"] == "msg-42"

    @pytest.mark.anyio
    async def test_create_with_owner(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
        assert record["owner_id"] == "user-1"

    @pytest.mark.anyio
    async def test_create_invalid_rating_zero(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=0)

    @pytest.mark.anyio
    async def test_create_invalid_rating_five(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        with pytest.raises(ValueError):
            await repo.create(run_id="r1", thread_id="t1", rating=5)

    @pytest.mark.anyio
    async def test_get(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        fetched = await repo.get(created["feedback_id"])
        assert fetched is not None
        assert fetched["feedback_id"] == created["feedback_id"]
        assert fetched["rating"] == 1

    @pytest.mark.anyio
    async def test_get_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        assert await repo.get("nonexistent") is None

    @pytest.mark.anyio
    async def test_list_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r1", thread_id="t1", rating=-1)
        await repo.create(run_id="r2", thread_id="t1", rating=1)
        results = await repo.list_by_run("t1", "r1")
        assert len(results) == 2
        assert all(r["run_id"] == "r1" for r in results)

    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r2", thread_id="t1", rating=-1)
        await repo.create(run_id="r3", thread_id="t2", rating=1)
        results = await repo.list_by_thread("t1")
        assert len(results) == 2
        assert all(r["thread_id"] == "t1" for r in results)

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        created = await repo.create(run_id="r1", thread_id="t1", rating=1)
        assert await repo.delete(created["feedback_id"]) is True
        assert await repo.get(created["feedback_id"]) is None

    @pytest.mark.anyio
    async def test_delete_nonexistent(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        assert await repo.delete("nonexistent") is False

    @pytest.mark.anyio
    async def test_aggregate_by_run(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r1", thread_id="t1", rating=1)
        await repo.create(run_id="r1", thread_id="t1", rating=-1)
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 3
        assert stats["positive"] == 2
        assert stats["negative"] == 1
        assert stats["run_id"] == "r1"

    @pytest.mark.anyio
    async def test_aggregate_empty(self, tmp_path):
        repo = await _make_feedback_repo(tmp_path)
        stats = await repo.aggregate_by_run("t1", "r1")
        assert stats["total"] == 0
        assert stats["positive"] == 0
        assert stats["negative"] == 0
# -- Follow-up association --
class TestFollowUpAssociation:
    """Follow-up run association: storage and auto-detection behavior."""

    @pytest.mark.anyio
    async def test_run_records_follow_up_via_memory_store(self):
        """MemoryRunStore stores follow_up_to_run_id in kwargs."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        # follow_up_to_run_id is not a first-class MemoryRunStore field,
        # so it travels inside the metadata dict instead.
        await store.put("r2", thread_id="t1", metadata={"follow_up_to_run_id": "r1"})
        stored = await store.get("r2")
        assert stored["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_human_message_has_follow_up_metadata(self):
        """human_message event metadata includes follow_up_to_run_id."""
        from deerflow.runtime.events.store.memory import MemoryRunEventStore

        event_store = MemoryRunEventStore()
        await event_store.put(
            thread_id="t1",
            run_id="r2",
            event_type="human_message",
            category="message",
            content="Tell me more about that",
            metadata={"follow_up_to_run_id": "r1"},
        )
        stored = await event_store.list_messages("t1")
        assert stored[0]["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_follow_up_auto_detection_logic(self):
        """Simulate the auto-detection: latest successful run becomes follow_up_to."""
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        def detect(runs):
            # Mirrors the gateway logic: only a *successful* newest run counts.
            if runs and runs[0].get("status") == "success":
                return runs[0]["run_id"]
            return None

        store = MemoryRunStore()
        await store.put("r1", thread_id="t1", status="success")
        await store.put("r2", thread_id="t1", status="error")
        # r2 (error) is newest, so nothing is detected.
        assert detect(await store.list_by_thread("t1", limit=1)) is None
        # A new successful run becomes the detected follow-up target.
        await store.put("r3", thread_id="t1", status="success")
        assert detect(await store.list_by_thread("t1", limit=1)) == "r3"