diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 5b2901107..3d1403d5f 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -42,6 +42,11 @@ logger = logging.getLogger(__name__) async def _ensure_admin_user(app: FastAPI) -> None: """Startup hook: generate init token on first boot; migrate orphan threads otherwise. + After admin creation, migrate orphan threads from the LangGraph + store (metadata.user_id unset) to the admin account. This is the + "no-auth → with-auth" upgrade path: users who ran DeerFlow without + authentication have existing LangGraph thread data that needs an + owner assigned. First boot (no admin exists): - Generates a one-time ``init_token`` stored in ``app.state.init_token`` - Logs the token to stdout so the operator can copy-paste it into the @@ -52,7 +57,7 @@ async def _ensure_admin_user(app: FastAPI) -> None: - Runs the one-time "no-auth → with-auth" orphan thread migration for existing LangGraph thread metadata that has no owner_id. - No SQL persistence migration is needed: the four owner_id columns + No SQL persistence migration is needed: the four user_id columns (threads_meta, runs, run_events, feedback) only come into existence alongside the auth module via create_all, so freshly created tables never contain NULL-owner rows. @@ -96,6 +101,8 @@ async def _ensure_admin_user(app: FastAPI) -> None: admin_id = str(row.id) # LangGraph store orphan migration — non-fatal. + # This covers the "no-auth → with-auth" upgrade path for users + # whose existing LangGraph thread metadata has no user_id set. store = getattr(app.state, "store", None) if store is not None: try: @@ -127,7 +134,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500): async def _migrate_orphaned_threads(store, admin_user_id: str) -> int: - """Migrate LangGraph store threads with no owner_id to the given admin. + """Migrate LangGraph store threads with no user_id to the given admin. Uses cursor pagination so all orphans are migrated regardless of count. Returns the number of rows migrated. @@ -135,8 +142,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int: migrated = 0 async for item in _iter_store_items(store, ("threads",)): metadata = item.value.get("metadata", {}) - if not metadata.get("owner_id"): - metadata["owner_id"] = admin_user_id + if not metadata.get("user_id"): + metadata["user_id"] = admin_user_id item.value["metadata"] = metadata await store.aput(("threads",), item.key, item.value) migrated += 1 diff --git a/backend/app/gateway/authz.py b/backend/app/gateway/authz.py index fa2e5f2d5..5842a24c7 100644 --- a/backend/app/gateway/authz.py +++ b/backend/app/gateway/authz.py @@ -233,18 +233,18 @@ def require_permission( # (``threads_meta`` table). We verify ownership via # ``ThreadMetaStore.check_access``: it returns True for # missing rows (untracked legacy thread) and for rows whose - # ``owner_id`` is NULL (shared / pre-auth data), so this is + # ``user_id`` is NULL (shared / pre-auth data), so this is # strict-deny rather than strict-allow — only an *existing* - # row with a *different* owner_id triggers 404. + # row with a *different* user_id triggers 404. if owner_check: thread_id = kwargs.get("thread_id") if thread_id is None: raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") - from app.gateway.deps import get_thread_meta_repo + from app.gateway.deps import get_thread_store - thread_meta_repo = get_thread_meta_repo(request) - allowed = await thread_meta_repo.check_access( + thread_store = get_thread_store(request) + allowed = await thread_store.check_access( thread_id, str(auth.user.id), require_existing=require_existing, diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 5ea7f6751..f4fdad473 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -1,8 +1,7 @@ """Centralized accessors for singleton objects stored on ``app.state``. **Getters** (used by routers): raise 503 when a required dependency is -missing, except ``get_store`` and ``get_thread_meta_repo`` which return -``None``. +missing, except ``get_store`` which returns ``None``. Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. """ @@ -20,6 +19,7 @@ from deerflow.runtime import RunContext, RunManager if TYPE_CHECKING: from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.repositories.sqlite import SQLiteUserRepository + from deerflow.persistence.thread_meta.base import ThreadMetaStore @asynccontextmanager @@ -31,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: async with langgraph_runtime(app): yield """ - from deerflow.agents.checkpointer.async_provider import make_checkpointer from deerflow.config import get_app_config from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config from deerflow.runtime import make_store, make_stream_bridge + from deerflow.runtime.checkpointer.async_provider import make_checkpointer from deerflow.runtime.events.store import make_run_event_store async with AsyncExitStack() as stack: @@ -53,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: if sf is not None: from deerflow.persistence.feedback import FeedbackRepository from deerflow.persistence.run import RunRepository - from deerflow.persistence.thread_meta import ThreadMetaRepository app.state.run_store = RunRepository(sf) app.state.feedback_repo = FeedbackRepository(sf) - app.state.thread_meta_repo = ThreadMetaRepository(sf) else: - from deerflow.persistence.thread_meta import MemoryThreadMetaStore from deerflow.runtime.runs.store.memory import MemoryRunStore app.state.run_store = MemoryRunStore() app.state.feedback_repo = None - app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store) + + from deerflow.persistence.thread_meta import make_thread_store + + app.state.thread_store = make_thread_store(sf, app.state.store) # Run event store (has its own factory with config-driven backend selection) run_events_config = getattr(config, "run_events", None) @@ -80,7 +80,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: # --------------------------------------------------------------------------- -# Getters -- called by routers per-request +# Getters – called by routers per-request # --------------------------------------------------------------------------- @@ -110,7 +110,12 @@ def get_store(request: Request): return getattr(request.app.state, "store", None) -get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store") +def get_thread_store(request: Request) -> ThreadMetaStore: + """Return the thread metadata store (SQL or memory-backed).""" + val = getattr(request.app.state, "thread_store", None) + if val is None: + raise HTTPException(status_code=503, detail="Thread metadata store not available") + return val def get_run_context(request: Request) -> RunContext: @@ -128,10 +133,11 @@ def get_run_context(request: Request) -> RunContext: store=get_store(request), event_store=get_run_event_store(request), run_events_config=getattr(get_app_config(), "run_events", None), - thread_meta_repo=get_thread_meta_repo(request), + thread_store=get_thread_store(request), ) + # --------------------------------------------------------------------------- # Auth helpers (used by authz.py and auth middleware) # --------------------------------------------------------------------------- diff --git a/backend/app/gateway/langgraph_auth.py b/backend/app/gateway/langgraph_auth.py index 25d3b434c..06074b9b8 100644 --- a/backend/app/gateway/langgraph_auth.py +++ b/backend/app/gateway/langgraph_auth.py @@ -93,14 +93,14 @@ async def authenticate(request): @auth.on async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict): - """Inject owner_id metadata on writes; filter by owner_id on reads. + """Inject user_id metadata on writes; filter by user_id on reads. - Gateway stores thread ownership as ``metadata.owner_id``. + Gateway stores thread ownership as ``metadata.user_id``. This handler ensures LangGraph Server enforces the same isolation. """ - # On create/update: stamp owner_id into metadata + # On create/update: stamp user_id into metadata metadata = value.setdefault("metadata", {}) - metadata["owner_id"] = ctx.user.identity + metadata["user_id"] = ctx.user.identity # Return filter dict — LangGraph applies it to search/read/delete - return {"owner_id": ctx.user.identity} + return {"user_id": ctx.user.identity} diff --git a/backend/app/gateway/routers/feedback.py b/backend/app/gateway/routers/feedback.py index 2bf631d01..ca5c1d406 100644 --- a/backend/app/gateway/routers/feedback.py +++ b/backend/app/gateway/routers/feedback.py @@ -30,11 +30,16 @@ class FeedbackCreateRequest(BaseModel): message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message") +class FeedbackUpsertRequest(BaseModel): + rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)") + comment: str | None = Field(default=None, description="Optional text feedback") + + class FeedbackResponse(BaseModel): feedback_id: str run_id: str thread_id: str - owner_id: str | None = None + user_id: str | None = None message_id: str | None = None rating: int comment: str | None = None @@ -53,6 +58,57 @@ class FeedbackStatsResponse(BaseModel): # --------------------------------------------------------------------------- +@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) +@require_permission("threads", "write", owner_check=True, require_existing=True) +async def upsert_feedback( + thread_id: str, + run_id: str, + body: FeedbackUpsertRequest, + request: Request, +) -> dict[str, Any]: + """Create or update feedback for a run (idempotent).""" + if body.rating not in (1, -1): + raise HTTPException(status_code=400, detail="rating must be +1 or -1") + + user_id = await get_current_user(request) + + run_store = get_run_store(request) + run = await run_store.get(run_id) + if run is None: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + if run.get("thread_id") != thread_id: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}") + + feedback_repo = get_feedback_repo(request) + return await feedback_repo.upsert( + run_id=run_id, + thread_id=thread_id, + rating=body.rating, + user_id=user_id, + comment=body.comment, + ) + + +@router.delete("/{thread_id}/runs/{run_id}/feedback") +@require_permission("threads", "delete", owner_check=True, require_existing=True) +async def delete_run_feedback( + thread_id: str, + run_id: str, + request: Request, +) -> dict[str, bool]: + """Delete the current user's feedback for a run.""" + user_id = await get_current_user(request) + feedback_repo = get_feedback_repo(request) + deleted = await feedback_repo.delete_by_run( + thread_id=thread_id, + run_id=run_id, + user_id=user_id, + ) + if not deleted: + raise HTTPException(status_code=404, detail="No feedback found for this run") + return {"success": True} + + @router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) @require_permission("threads", "write", owner_check=True, require_existing=True) async def create_feedback( @@ -80,7 +136,7 @@ async def create_feedback( run_id=run_id, thread_id=thread_id, rating=body.rating, - owner_id=user_id, + user_id=user_id, message_id=body.message_id, comment=body.comment, ) diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index 904f94ff0..e414801ee 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -20,7 +20,7 @@ from fastapi.responses import Response, StreamingResponse from pydantic import BaseModel, Field from app.gateway.authz import require_permission -from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge +from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.services import sse_consumer, start_run from deerflow.runtime import RunRecord, serialize_channel_values @@ -291,9 +291,36 @@ async def list_thread_messages( before_seq: int | None = Query(default=None), after_seq: int | None = Query(default=None), ) -> list[dict]: - """Return displayable messages for a thread (across all runs).""" + """Return displayable messages for a thread (across all runs), with feedback attached.""" event_store = get_run_event_store(request) - return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq) + messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq) + + # Attach feedback to the last AI message of each run + feedback_repo = get_feedback_repo(request) + user_id = await get_current_user(request) + feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id) + + # Find the last ai_message per run_id + last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list + for i, msg in enumerate(messages): + if msg.get("event_type") == "ai_message": + last_ai_per_run[msg["run_id"]] = i + + # Attach feedback field + last_ai_indices = set(last_ai_per_run.values()) + for i, msg in enumerate(messages): + if i in last_ai_indices: + run_id = msg["run_id"] + fb = feedback_map.get(run_id) + msg["feedback"] = { + "feedback_id": fb["feedback_id"], + "rating": fb["rating"], + "comment": fb.get("comment"), + } if fb else None + else: + msg["feedback"] = None + + return messages @router.get("/{thread_id}/runs/{run_id}/messages") diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index c7c7b4053..5eb4a30b5 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -13,6 +13,7 @@ matching the LangGraph Platform wire format expected by the from __future__ import annotations import logging +import re import time import uuid from typing import Any @@ -21,7 +22,7 @@ from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field, field_validator from app.gateway.authz import require_permission -from app.gateway.deps import get_checkpointer +from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store from app.gateway.utils import sanitize_log_param from deerflow.config.paths import Paths, get_paths from deerflow.runtime import serialize_channel_values @@ -34,7 +35,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"]) # them. Pydantic ``@field_validator("metadata")`` strips them on every # inbound model below so a malicious client cannot reflect a forged # owner identity through the API surface. Defense-in-depth — the -# row-level invariant is still ``threads_meta.owner_id`` populated from +# row-level invariant is still ``threads_meta.user_id`` populated from # the auth contextvar; this list closes the metadata-blob echo gap. _SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"}) @@ -194,7 +195,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe and removes the thread_meta row from the configured ThreadMetaStore (sqlite or memory). """ - from app.gateway.deps import get_thread_meta_repo + from app.gateway.deps import get_thread_store # Clean local filesystem response = _delete_thread_data(thread_id) @@ -211,8 +212,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe # Remove thread_meta row (best-effort) — required for sqlite backend # so the deleted thread no longer appears in /threads/search. try: - thread_meta_repo = get_thread_meta_repo(request) - await thread_meta_repo.delete(thread_id) + thread_store = get_thread_store(request) + await thread_store.delete(thread_id) except Exception: logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id)) @@ -227,17 +228,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe and an empty checkpoint (so state endpoints work immediately). Idempotent: returns the existing record when ``thread_id`` already exists. """ - from app.gateway.deps import get_thread_meta_repo + from app.gateway.deps import get_thread_store checkpointer = get_checkpointer(request) - thread_meta_repo = get_thread_meta_repo(request) + thread_store = get_thread_store(request) thread_id = body.thread_id or str(uuid.uuid4()) now = time.time() # ``body.metadata`` is already stripped of server-reserved keys by # ``ThreadCreateRequest._strip_reserved`` — see the model definition. # Idempotency: return existing record when already present - existing_record = await thread_meta_repo.get(thread_id) + existing_record = await thread_store.get(thread_id) if existing_record is not None: return ThreadResponse( thread_id=thread_id, @@ -249,7 +250,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe # Write thread_meta so the thread appears in /threads/search immediately try: - await thread_meta_repo.create( + await thread_store.create( thread_id, assistant_id=getattr(body, "assistant_id", None), metadata=body.metadata, @@ -293,9 +294,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th Delegates to the configured ThreadMetaStore implementation (SQL-backed for sqlite/postgres, Store-backed for memory mode). """ - from app.gateway.deps import get_thread_meta_repo + from app.gateway.deps import get_thread_store - repo = get_thread_meta_repo(request) + repo = get_thread_store(request) rows = await repo.search( metadata=body.metadata or None, status=body.status, @@ -320,22 +321,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th @require_permission("threads", "write", owner_check=True, require_existing=True) async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: """Merge metadata into a thread record.""" - from app.gateway.deps import get_thread_meta_repo + from app.gateway.deps import get_thread_store - thread_meta_repo = get_thread_meta_repo(request) - record = await thread_meta_repo.get(thread_id) + thread_store = get_thread_store(request) + record = await thread_store.get(thread_id) if record is None: raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") # ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``. try: - await thread_meta_repo.update_metadata(thread_id, body.metadata) + await thread_store.update_metadata(thread_id, body.metadata) except Exception: logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to update thread") # Re-read to get the merged metadata + refreshed updated_at - record = await thread_meta_repo.get(thread_id) or record + record = await thread_store.get(thread_id) or record return ThreadResponse( thread_id=thread_id, status=record.get("status", "idle"), @@ -354,12 +355,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse: execution status from the checkpointer. Falls back to the checkpointer alone for threads that pre-date ThreadMetaStore adoption (backward compat). """ - from app.gateway.deps import get_thread_meta_repo + from app.gateway.deps import get_thread_store - thread_meta_repo = get_thread_meta_repo(request) + thread_store = get_thread_store(request) checkpointer = get_checkpointer(request) - record: dict | None = await thread_meta_repo.get(thread_id) + record: dict | None = await thread_store.get(thread_id) # Derive accurate status from the checkpointer config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} @@ -402,6 +403,165 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse: ) +# --------------------------------------------------------------------------- +# Event-store-backed message loader +# --------------------------------------------------------------------------- + +_LEGACY_CMD_INNER_CONTENT_RE = re.compile( + r"ToolMessage\(content=(?P['\"])(?P.*?)(?P=q)", + re.DOTALL, +) + + +def _sanitize_legacy_command_repr(content_field: Any) -> Any: + """Recover the inner ToolMessage text from a legacy ``str(Command(...))`` repr. + + Runs captured before the ``on_tool_end`` fix in ``journal.py`` stored + ``str(Command(update={'messages':[ToolMessage(content='X', ...)]}))`` as the + tool_result content. New runs store ``'X'`` directly. For legacy rows, try + to extract ``'X'`` defensively; return the original string if extraction + fails (still no worse than the checkpoint fallback for summarized threads). + """ + if not isinstance(content_field, str) or not content_field.startswith("Command(update="): + return content_field + match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field) + return match.group("inner") if match else content_field + + +async def _get_event_store_messages(request: Request, thread_id: str) -> list[dict] | None: + """Load the full message stream for ``thread_id`` from the event store. + + The event store is append-only and unaffected by summarization — the + checkpoint's ``channel_values["messages"]`` is rewritten in-place when the + SummarizationMiddleware runs, which drops all pre-summarize messages. The + event store retains the full transcript, so callers in Gateway mode should + prefer it for rendering the conversation history. + + In addition to the core message content, this helper attaches two extra + fields to every returned dict: + + - ``run_id``: the ``run_id`` of the event that produced this message. + Always present. + - ``feedback``: thumbs-up/down data. Present only on the **final + ``ai_message`` of each run** (matching the per-run feedback semantics + of ``POST /api/threads/{id}/runs/{run_id}/feedback``). The frontend uses + the presence of this field to decide whether to render the feedback + button, which sidesteps the positional-index mapping bug that an + out-of-band ``/messages`` fetch exhibited. + + Behaviour contract: + + - **Full pagination.** ``RunEventStore.list_messages`` returns the newest + ``limit`` records when no cursor is given, so a fixed limit silently + drops older messages on long threads. We size the read from + ``count_messages()`` and then page forward with ``after_seq`` cursors. + - **Copy-on-read.** Each content dict is copied before ``id`` is patched + so the live store object is never mutated; ``MemoryRunEventStore`` + returns live references. + - **Stable ids.** Messages with ``id=None`` (human + tool_result) receive + a deterministic ``uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")`` so React + keys are stable across requests without altering stored data. AI messages + retain their LLM-assigned ``lc_run--*`` ids. + - **Legacy Command repr.** Rows captured before the ``journal.py`` + ``on_tool_end`` fix stored ``str(Command(update={...}))`` as the tool + result content. ``_sanitize_legacy_command_repr`` extracts the inner + ToolMessage text. + - **User context.** ``DbRunEventStore`` is user-scoped by default via + ``resolve_user_id(AUTO)`` in ``runtime/user_context.py``. This helper + must run inside a request where ``@require_permission`` has populated + the user contextvar. Both callers below are decorated appropriately. + Do not call this helper from CLI or migration scripts without passing + ``user_id=None`` explicitly to the underlying store methods. + + Returns ``None`` when the event store is not configured or has no message + events for this thread, so callers fall back to checkpoint messages. + """ + try: + event_store = get_run_event_store(request) + except Exception: + return None + + try: + total = await event_store.count_messages(thread_id) + except Exception: + logger.exception("count_messages failed for thread %s", sanitize_log_param(thread_id)) + return None + if not total: + return None + + # Batch by page_size to keep memory bounded for very long threads. + page_size = 500 + collected: list[dict] = [] + after_seq: int | None = None + while True: + try: + page = await event_store.list_messages(thread_id, limit=page_size, after_seq=after_seq) + except Exception: + logger.exception("list_messages failed for thread %s", sanitize_log_param(thread_id)) + return None + if not page: + break + collected.extend(page) + if len(page) < page_size: + break + next_cursor = page[-1].get("seq") + if next_cursor is None or (after_seq is not None and next_cursor <= after_seq): + break + after_seq = next_cursor + + # Build the message list; track the final ``ai_message`` index per run so + # feedback can be attached at the right position (matches thread_runs.py). + messages: list[dict] = [] + last_ai_per_run: dict[str, int] = {} + for evt in collected: + raw = evt.get("content") + if not isinstance(raw, dict) or "type" not in raw: + continue + content = dict(raw) + if content.get("id") is None: + content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{thread_id}:{evt['seq']}")) + if content.get("type") == "tool": + content["content"] = _sanitize_legacy_command_repr(content.get("content")) + run_id = evt.get("run_id") + if run_id: + content["run_id"] = run_id + if evt.get("event_type") == "ai_message" and run_id: + last_ai_per_run[run_id] = len(messages) + messages.append(content) + + if not messages: + return None + + # Attach feedback to the final ai_message of each run. If the feedback + # subsystem is unavailable, leave the ``feedback`` field absent entirely + # so the frontend hides the button rather than showing it over a broken + # write path. + feedback_available = False + feedback_map: dict[str, dict] = {} + try: + feedback_repo = get_feedback_repo(request) + user_id = await get_current_user(request) + feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id) + feedback_available = True + except Exception: + logger.exception("feedback lookup failed for thread %s", sanitize_log_param(thread_id)) + + if feedback_available: + for run_id, idx in last_ai_per_run.items(): + fb = feedback_map.get(run_id) + messages[idx]["feedback"] = ( + { + "feedback_id": fb["feedback_id"], + "rating": fb["rating"], + "comment": fb.get("comment"), + } + if fb + else None + ) + + return messages + + @router.get("/{thread_id}/state", response_model=ThreadStateResponse) @require_permission("threads", "read", owner_check=True) async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse: @@ -440,8 +600,15 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")] tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw] + values = serialize_channel_values(channel_values) + + # Prefer event-store messages: append-only, immune to summarization. + es_messages = await _get_event_store_messages(request, thread_id) + if es_messages is not None: + values["messages"] = es_messages + return ThreadStateResponse( - values=serialize_channel_values(channel_values), + values=values, next=next_tasks, metadata=metadata, checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))}, @@ -462,10 +629,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re ThreadMetaStore abstraction so that ``/threads/search`` reflects the change immediately in both sqlite and memory backends. """ - from app.gateway.deps import get_thread_meta_repo + from app.gateway.deps import get_thread_store checkpointer = get_checkpointer(request) - thread_meta_repo = get_thread_meta_repo(request) + thread_store = get_thread_store(request) # checkpoint_ns must be present in the config for aput — default to "" # (the root graph namespace). checkpoint_id is optional; omitting it @@ -529,7 +696,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re new_title = body.values["title"] if new_title: # Skip empty strings and None try: - await thread_meta_repo.update_display_name(thread_id, new_title) + await thread_store.update_display_name(thread_id, new_title) except Exception: logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) @@ -559,6 +726,11 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request if body.before: config["configurable"]["checkpoint_id"] = body.before + # Load the full event-store message stream once; attach to the latest + # checkpoint entry only (matching the prior semantics). The event store + # is append-only and immune to summarization. + es_messages = await _get_event_store_messages(request, thread_id) + entries: list[HistoryEntry] = [] is_latest_checkpoint = True try: @@ -582,11 +754,17 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request if thread_data := channel_values.get("thread_data"): values["thread_data"] = thread_data - # Attach messages from checkpointer only for the latest checkpoint + # Attach messages only to the latest checkpoint. Prefer the + # event-store stream (complete and unaffected by summarization); + # fall back to checkpoint channel_values when the event store is + # unavailable or empty. if is_latest_checkpoint: - messages = channel_values.get("messages") - if messages: - values["messages"] = serialize_channel_values({"messages": messages}).get("messages", []) + if es_messages is not None: + values["messages"] = es_messages + else: + messages = channel_values.get("messages") + if messages: + values["messages"] = serialize_channel_values({"messages": messages}).get("messages", []) is_latest_checkpoint = False # Derive next tasks diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index d6acffd48..72f074907 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -229,15 +229,15 @@ async def start_run( # even for threads that were never explicitly created via POST /threads # (e.g. stateless runs). try: - existing = await run_ctx.thread_meta_repo.get(thread_id) + existing = await run_ctx.thread_store.get(thread_id) if existing is None: - await run_ctx.thread_meta_repo.create( + await run_ctx.thread_store.create( thread_id, assistant_id=body.assistant_id, metadata=body.metadata, ) else: - await run_ctx.thread_meta_repo.update_status(thread_id, "running") + await run_ctx.thread_store.update_status(thread_id, "running") except Exception: logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) @@ -285,7 +285,7 @@ async def start_run( record.task = task # Title sync is handled by worker.py's finally block which reads the - # title from the checkpoint and calls thread_meta_repo.update_display_name + # title from the checkpoint and calls thread_store.update_display_name # after the run completes. return record diff --git a/backend/docs/TITLE_GENERATION_IMPLEMENTATION.md b/backend/docs/TITLE_GENERATION_IMPLEMENTATION.md index 07a026e79..87e8aa61a 100644 --- a/backend/docs/TITLE_GENERATION_IMPLEMENTATION.md +++ b/backend/docs/TITLE_GENERATION_IMPLEMENTATION.md @@ -124,7 +124,7 @@ title: # checkpointer.py from langgraph.checkpoint.sqlite import SqliteSaver -checkpointer = SqliteSaver.from_conn_string("checkpoints.db") +checkpointer = SqliteSaver.from_conn_string("deerflow.db") ``` ```json diff --git a/backend/packages/harness/deerflow/agents/__init__.py b/backend/packages/harness/deerflow/agents/__init__.py index 2c31a514a..397f67f8e 100644 --- a/backend/packages/harness/deerflow/agents/__init__.py +++ b/backend/packages/harness/deerflow/agents/__init__.py @@ -1,4 +1,3 @@ -from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer from .factory import create_deerflow_agent from .features import Next, Prev, RuntimeFeatures from .lead_agent import make_lead_agent @@ -18,7 +17,4 @@ __all__ = [ "make_lead_agent", "SandboxState", "ThreadState", - "get_checkpointer", - "reset_checkpointer", - "make_checkpointer", ] diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index 1c64ba52a..950fdb085 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -240,7 +240,7 @@ class DeerFlowClient: } checkpointer = self._checkpointer if checkpointer is None: - from deerflow.agents.checkpointer import get_checkpointer + from deerflow.runtime.checkpointer import get_checkpointer checkpointer = get_checkpointer() if checkpointer is not None: @@ -374,7 +374,7 @@ class DeerFlowClient: """ checkpointer = self._checkpointer if checkpointer is None: - from deerflow.agents.checkpointer.provider import get_checkpointer + from deerflow.runtime.checkpointer.provider import get_checkpointer checkpointer = get_checkpointer() @@ -429,7 +429,7 @@ class DeerFlowClient: """ checkpointer = self._checkpointer if checkpointer is None: - from deerflow.agents.checkpointer.provider import get_checkpointer + from deerflow.runtime.checkpointer.provider import get_checkpointer checkpointer = get_checkpointer() diff --git a/backend/packages/harness/deerflow/config/database_config.py b/backend/packages/harness/deerflow/config/database_config.py index 207f8ef77..37cfd579d 100644 --- a/backend/packages/harness/deerflow/config/database_config.py +++ b/backend/packages/harness/deerflow/config/database_config.py @@ -4,8 +4,12 @@ Controls BOTH the LangGraph checkpointer and the DeerFlow application persistence layer (runs, threads metadata, users, etc.). The user configures one backend; the system handles physical separation details. -SQLite mode: checkpointer and app use different .db files in the same -directory to avoid write-lock contention. This is automatic. +SQLite mode: checkpointer and app share a single .db file +({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every +connection. WAL allows concurrent readers and a single writer without +blocking, making a unified file safe for both workloads. Writers +that contend for the lock wait via the default 5-second sqlite3 +busy timeout rather than failing immediately. Postgres mode: both use the same database URL but maintain independent connection pools with different lifecycles. @@ -40,7 +44,7 @@ class DatabaseConfig(BaseModel): ) sqlite_dir: str = Field( default=".deer-flow/data", - description=("Directory for SQLite database files. Checkpointer uses {sqlite_dir}/checkpoints.db, application data uses {sqlite_dir}/app.db."), + description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."), ) postgres_url: str = Field( default="", @@ -69,21 +73,27 @@ class DatabaseConfig(BaseModel): return str(Path(self.sqlite_dir).resolve()) + @property + def sqlite_path(self) -> str: + """Unified SQLite file path shared by checkpointer and app.""" + return os.path.join(self._resolved_sqlite_dir, "deerflow.db") + + # Backward-compatible aliases @property def checkpointer_sqlite_path(self) -> str: - """SQLite file path for the LangGraph checkpointer.""" - return os.path.join(self._resolved_sqlite_dir, "checkpoints.db") + """SQLite file path for the LangGraph checkpointer (alias for sqlite_path).""" + return self.sqlite_path @property def app_sqlite_path(self) -> str: - """SQLite file path for application ORM data.""" - return os.path.join(self._resolved_sqlite_dir, "app.db") + """SQLite file path for application ORM data (alias for sqlite_path).""" + return self.sqlite_path @property def app_sqlalchemy_url(self) -> str: """SQLAlchemy async URL for the application ORM engine.""" if self.backend == "sqlite": - return f"sqlite+aiosqlite:///{self.app_sqlite_path}" + return f"sqlite+aiosqlite:///{self.sqlite_path}" if self.backend == "postgres": url = self.postgres_url if url.startswith("postgresql://"): diff --git a/backend/packages/harness/deerflow/persistence/engine.py b/backend/packages/harness/deerflow/persistence/engine.py index 7e374788c..2777c2450 100644 --- a/backend/packages/harness/deerflow/persistence/engine.py +++ b/backend/packages/harness/deerflow/persistence/engine.py @@ -98,6 +98,11 @@ async def init_engine( # SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion # ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only # at WAL checkpoint boundaries instead of every commit. + # Note: we do not set PRAGMA busy_timeout here — Python's sqlite3 + # driver already defaults to a 5-second busy timeout (see the + # ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite / + # SQLAlchemy's aiosqlite dialect inherit that default. Setting + # it again would be a no-op. @event.listens_for(_engine.sync_engine, "connect") def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract cursor = dbapi_conn.cursor() diff --git a/backend/packages/harness/deerflow/persistence/feedback/model.py b/backend/packages/harness/deerflow/persistence/feedback/model.py index 221fb5fb1..f06bc84e7 100644 --- a/backend/packages/harness/deerflow/persistence/feedback/model.py +++ b/backend/packages/harness/deerflow/persistence/feedback/model.py @@ -4,7 +4,7 @@ from __future__ import annotations from datetime import UTC, datetime -from sqlalchemy import DateTime, String, Text +from sqlalchemy import DateTime, String, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column from deerflow.persistence.base import Base @@ -13,10 +13,14 @@ from deerflow.persistence.base import Base class FeedbackRow(Base): __tablename__ = "feedback" + __table_args__ = ( + UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"), + ) + feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True) run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) - owner_id: Mapped[str | None] = mapped_column(String(64), index=True) + user_id: Mapped[str | None] = mapped_column(String(64), index=True) message_id: Mapped[str | None] = mapped_column(String(64)) # message_id is an optional RunEventStore event identifier — # allows feedback to target a specific message or the entire run diff --git a/backend/packages/harness/deerflow/persistence/feedback/sql.py b/backend/packages/harness/deerflow/persistence/feedback/sql.py index 903124953..1db74ce84 100644 --- a/backend/packages/harness/deerflow/persistence/feedback/sql.py +++ b/backend/packages/harness/deerflow/persistence/feedback/sql.py @@ -12,7 +12,7 @@ from sqlalchemy import case, func, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.feedback.model import FeedbackRow -from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id class FeedbackRepository: @@ -33,19 +33,19 @@ class FeedbackRepository: run_id: str, thread_id: str, rating: int, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, message_id: str | None = None, comment: str | None = None, ) -> dict: """Create a feedback record. rating must be +1 or -1.""" if rating not in (1, -1): raise ValueError(f"rating must be +1 or -1, got {rating}") - resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create") + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create") row = FeedbackRow( feedback_id=str(uuid.uuid4()), run_id=run_id, thread_id=thread_id, - owner_id=resolved_owner_id, + user_id=resolved_user_id, message_id=message_id, rating=rating, comment=comment, @@ -61,14 +61,14 @@ class FeedbackRepository: self, feedback_id: str, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> dict | None: - resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get") + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get") async with self._sf() as session: row = await session.get(FeedbackRow, feedback_id) if row is None: return None - if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + if resolved_user_id is not None and row.user_id != resolved_user_id: return None return self._row_to_dict(row) @@ -78,12 +78,12 @@ class FeedbackRepository: run_id: str, *, limit: int = 100, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> list[dict]: - resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run") + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run") stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id) - if resolved_owner_id is not None: - stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(FeedbackRow.user_id == resolved_user_id) stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) async with self._sf() as session: result = await session.execute(stmt) @@ -94,12 +94,12 @@ class FeedbackRepository: thread_id: str, *, limit: int = 100, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> list[dict]: - resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread") + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread") stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id) - if resolved_owner_id is not None: - stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(FeedbackRow.user_id == resolved_user_id) stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) async with self._sf() as session: result = await session.execute(stmt) @@ -109,19 +109,97 @@ class FeedbackRepository: self, feedback_id: str, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> bool: - resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete") + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete") async with self._sf() as session: row = await session.get(FeedbackRow, feedback_id) if row is None: return False - if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + if resolved_user_id is not None and row.user_id != resolved_user_id: return False await session.delete(row) await session.commit() return True + async def upsert( + self, + *, + run_id: str, + thread_id: str, + rating: int, + user_id: str | None | _AutoSentinel = AUTO, + comment: str | None = None, + ) -> dict: + """Create or update feedback for (thread_id, run_id, user_id). rating must be +1 or -1.""" + if rating not in (1, -1): + raise ValueError(f"rating must be +1 or -1, got {rating}") + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert") + async with self._sf() as session: + stmt = select(FeedbackRow).where( + FeedbackRow.thread_id == thread_id, + FeedbackRow.run_id == run_id, + FeedbackRow.user_id == resolved_user_id, + ) + result = await session.execute(stmt) + row = result.scalar_one_or_none() + if row is not None: + row.rating = rating + row.comment = comment + row.created_at = datetime.now(UTC) + else: + row = FeedbackRow( + feedback_id=str(uuid.uuid4()), + run_id=run_id, + thread_id=thread_id, + user_id=resolved_user_id, + rating=rating, + comment=comment, + created_at=datetime.now(UTC), + ) + session.add(row) + await session.commit() + await session.refresh(row) + return self._row_to_dict(row) + + async def delete_by_run( + self, + *, + thread_id: str, + run_id: str, + user_id: str | None | _AutoSentinel = AUTO, + ) -> bool: + """Delete the current user's feedback for a run. Returns True if a record was deleted.""" + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run") + async with self._sf() as session: + stmt = select(FeedbackRow).where( + FeedbackRow.thread_id == thread_id, + FeedbackRow.run_id == run_id, + FeedbackRow.user_id == resolved_user_id, + ) + result = await session.execute(stmt) + row = result.scalar_one_or_none() + if row is None: + return False + await session.delete(row) + await session.commit() + return True + + async def list_by_thread_grouped( + self, + thread_id: str, + *, + user_id: str | None | _AutoSentinel = AUTO, + ) -> dict[str, dict]: + """Return feedback grouped by run_id for a thread: {run_id: feedback_dict}.""" + resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped") + stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id) + if resolved_user_id is not None: + stmt = stmt.where(FeedbackRow.user_id == resolved_user_id) + async with self._sf() as session: + result = await session.execute(stmt) + return {row.run_id: self._row_to_dict(row) for row in result.scalars()} + async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict: """Aggregate feedback stats for a run using database-side counting.""" stmt = select( diff --git a/backend/packages/harness/deerflow/persistence/migrations/alembic.ini b/backend/packages/harness/deerflow/persistence/migrations/alembic.ini index adeccef32..71b4b1dc0 100644 --- a/backend/packages/harness/deerflow/persistence/migrations/alembic.ini +++ b/backend/packages/harness/deerflow/persistence/migrations/alembic.ini @@ -2,7 +2,7 @@ script_location = %(here)s # Default URL for offline mode / autogenerate. # Runtime uses engine from DeerFlow config. -sqlalchemy.url = sqlite+aiosqlite:///./data/app.db +sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db [loggers] keys = root,sqlalchemy,alembic diff --git a/backend/packages/harness/deerflow/persistence/models/run_event.py b/backend/packages/harness/deerflow/persistence/models/run_event.py index 34f55ba03..4f22b4616 100644 --- a/backend/packages/harness/deerflow/persistence/models/run_event.py +++ b/backend/packages/harness/deerflow/persistence/models/run_event.py @@ -19,7 +19,7 @@ class RunEventRow(Base): # Owner of the conversation this event belongs to. Nullable for data # created before auth was introduced; populated by auth middleware on # new writes and by the boot-time orphan migration on existing rows. - owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True) + user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True) event_type: Mapped[str] = mapped_column(String(32), nullable=False) category: Mapped[str] = mapped_column(String(16), nullable=False) # "message" | "trace" | "lifecycle" diff --git a/backend/packages/harness/deerflow/persistence/run/model.py b/backend/packages/harness/deerflow/persistence/run/model.py index 67396bc25..d0dfe4085 100644 --- a/backend/packages/harness/deerflow/persistence/run/model.py +++ b/backend/packages/harness/deerflow/persistence/run/model.py @@ -16,7 +16,7 @@ class RunRow(Base): run_id: Mapped[str] = mapped_column(String(64), primary_key=True) thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) assistant_id: Mapped[str | None] = mapped_column(String(128)) - owner_id: Mapped[str | None] = mapped_column(String(64), index=True) + user_id: Mapped[str | None] = mapped_column(String(64), index=True) status: Mapped[str] = mapped_column(String(20), default="pending") # "pending" | "running" | "success" | "error" | "timeout" | "interrupted" diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 5d8656509..fcd1a3411 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.run.model import RunRow from deerflow.runtime.runs.store.base import RunStore -from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id class RunRepository(RunStore): @@ -69,7 +69,7 @@ class RunRepository(RunStore): *, thread_id, assistant_id=None, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, status="pending", multitask_strategy="reject", metadata=None, @@ -78,13 +78,13 @@ class RunRepository(RunStore): created_at=None, follow_up_to_run_id=None, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put") + resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put") now = datetime.now(UTC) row = RunRow( run_id=run_id, thread_id=thread_id, assistant_id=assistant_id, - owner_id=resolved_owner_id, + user_id=resolved_user_id, status=status, multitask_strategy=multitask_strategy, metadata_json=self._safe_json(metadata) or {}, @@ -102,14 +102,14 @@ class RunRepository(RunStore): self, run_id, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get") + resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get") async with self._sf() as session: row = await session.get(RunRow, run_id) if row is None: return None - if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + if resolved_user_id is not None and row.user_id != resolved_user_id: return None return self._row_to_dict(row) @@ -117,13 +117,13 @@ class RunRepository(RunStore): self, thread_id, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, limit=100, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread") + resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread") stmt = select(RunRow).where(RunRow.thread_id == thread_id) - if resolved_owner_id is not None: - stmt = stmt.where(RunRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(RunRow.user_id == resolved_user_id) stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit) async with self._sf() as session: result = await session.execute(stmt) @@ -141,14 +141,14 @@ class RunRepository(RunStore): self, run_id, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete") + resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete") async with self._sf() as session: row = await session.get(RunRow, run_id) if row is None: return - if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + if resolved_user_id is not None and row.user_id != resolved_user_id: return await session.delete(row) await session.commit() diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py index 8e497bb7e..080ce8093 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py @@ -1,13 +1,38 @@ """Thread metadata persistence — ORM, abstract store, and concrete implementations.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.sql import ThreadMetaRepository +if TYPE_CHECKING: + from langgraph.store.base import BaseStore + from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + __all__ = [ "MemoryThreadMetaStore", "ThreadMetaRepository", "ThreadMetaRow", "ThreadMetaStore", + "make_thread_store", ] + + +def make_thread_store( + session_factory: async_sessionmaker[AsyncSession] | None, + store: BaseStore | None = None, +) -> ThreadMetaStore: + """Create the appropriate ThreadMetaStore based on available backends. + + Returns a SQL-backed repository when a session factory is available, + otherwise falls back to the in-memory LangGraph Store implementation. + """ + if session_factory is not None: + return ThreadMetaRepository(session_factory) + if store is None: + raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)") + return MemoryThreadMetaStore(store) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index 466a82a21..c87c10a16 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -3,12 +3,21 @@ Implementations: - ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy) - MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode) + +All mutating and querying methods accept a ``user_id`` parameter with +three-state semantics (see :mod:`deerflow.runtime.user_context`): + +- ``AUTO`` (default): resolve from the request-scoped contextvar. +- Explicit ``str``: use the provided value verbatim. +- Explicit ``None``: bypass owner filtering (migration/CLI only). """ from __future__ import annotations import abc +from deerflow.runtime.user_context import AUTO, _AutoSentinel + class ThreadMetaStore(abc.ABC): @abc.abstractmethod @@ -17,14 +26,14 @@ class ThreadMetaStore(abc.ABC): thread_id: str, *, assistant_id: str | None = None, - owner_id: str | None = None, + user_id: str | None | _AutoSentinel = AUTO, display_name: str | None = None, metadata: dict | None = None, ) -> dict: pass @abc.abstractmethod - async def get(self, thread_id: str) -> dict | None: + async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None: pass @abc.abstractmethod @@ -35,26 +44,33 @@ class ThreadMetaStore(abc.ABC): status: str | None = None, limit: int = 100, offset: int = 0, + user_id: str | None | _AutoSentinel = AUTO, ) -> list[dict]: pass @abc.abstractmethod - async def update_display_name(self, thread_id: str, display_name: str) -> None: + async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: pass @abc.abstractmethod - async def update_status(self, thread_id: str, status: str) -> None: + async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: pass @abc.abstractmethod - async def update_metadata(self, thread_id: str, metadata: dict) -> None: + async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None: """Merge ``metadata`` into the thread's metadata field. Existing keys are overwritten by the new values; keys absent from - ``metadata`` are preserved. No-op if the thread does not exist. + ``metadata`` are preserved. No-op if the thread does not exist + or the owner check fails. """ pass @abc.abstractmethod - async def delete(self, thread_id: str) -> None: + async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: + """Check if ``user_id`` has access to ``thread_id``.""" + pass + + @abc.abstractmethod + async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: pass diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index ab921f229..ccf59ad42 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -13,6 +13,7 @@ from typing import Any from langgraph.store.base import BaseStore from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id THREADS_NS: tuple[str, ...] = ("threads",) @@ -21,20 +22,37 @@ class MemoryThreadMetaStore(ThreadMetaStore): def __init__(self, store: BaseStore) -> None: self._store = store + async def _get_owned_record( + self, + thread_id: str, + user_id: str | None | _AutoSentinel, + method_name: str, + ) -> dict | None: + """Fetch a record and verify ownership. Returns a mutable copy, or None.""" + resolved = resolve_user_id(user_id, method_name=method_name) + item = await self._store.aget(THREADS_NS, thread_id) + if item is None: + return None + record = dict(item.value) + if resolved is not None and record.get("user_id") != resolved: + return None + return record + async def create( self, thread_id: str, *, assistant_id: str | None = None, - owner_id: str | None = None, + user_id: str | None | _AutoSentinel = AUTO, display_name: str | None = None, metadata: dict | None = None, ) -> dict: + resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create") now = time.time() record: dict[str, Any] = { "thread_id": thread_id, "assistant_id": assistant_id, - "owner_id": owner_id, + "user_id": resolved_user_id, "display_name": display_name, "status": "idle", "metadata": metadata or {}, @@ -45,9 +63,8 @@ class MemoryThreadMetaStore(ThreadMetaStore): await self._store.aput(THREADS_NS, thread_id, record) return record - async def get(self, thread_id: str) -> dict | None: - item = await self._store.aget(THREADS_NS, thread_id) - return item.value if item is not None else None + async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None: + return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get") async def search( self, @@ -56,12 +73,16 @@ class MemoryThreadMetaStore(ThreadMetaStore): status: str | None = None, limit: int = 100, offset: int = 0, + user_id: str | None | _AutoSentinel = AUTO, ) -> list[dict]: + resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search") filter_dict: dict[str, Any] = {} if metadata: filter_dict.update(metadata) if status: filter_dict["status"] = status + if resolved_user_id is not None: + filter_dict["user_id"] = resolved_user_id items = await self._store.asearch( THREADS_NS, @@ -71,37 +92,45 @@ class MemoryThreadMetaStore(ThreadMetaStore): ) return [self._item_to_dict(item) for item in items] - async def update_display_name(self, thread_id: str, display_name: str) -> None: + async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: item = await self._store.aget(THREADS_NS, thread_id) if item is None: + return not require_existing + record_user_id = item.value.get("user_id") + if record_user_id is None: + return True + return record_user_id == user_id + + async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name") + if record is None: return - record = dict(item.value) record["display_name"] = display_name record["updated_at"] = time.time() await self._store.aput(THREADS_NS, thread_id, record) - async def update_status(self, thread_id: str, status: str) -> None: - item = await self._store.aget(THREADS_NS, thread_id) - if item is None: + async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status") + if record is None: return - record = dict(item.value) record["status"] = status record["updated_at"] = time.time() await self._store.aput(THREADS_NS, thread_id, record) - async def update_metadata(self, thread_id: str, metadata: dict) -> None: - """Merge ``metadata`` into the in-memory record. No-op if absent.""" - item = await self._store.aget(THREADS_NS, thread_id) - if item is None: + async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata") + if record is None: return - record = dict(item.value) merged = dict(record.get("metadata") or {}) merged.update(metadata) record["metadata"] = merged record["updated_at"] = time.time() await self._store.aput(THREADS_NS, thread_id, record) - async def delete(self, thread_id: str) -> None: + async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete") + if record is None: + return await self._store.adelete(THREADS_NS, thread_id) @staticmethod @@ -111,7 +140,7 @@ class MemoryThreadMetaStore(ThreadMetaStore): return { "thread_id": item.key, "assistant_id": val.get("assistant_id"), - "owner_id": val.get("owner_id"), + "user_id": val.get("user_id"), "display_name": val.get("display_name"), "status": val.get("status", "idle"), "metadata": val.get("metadata", {}), diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/model.py b/backend/packages/harness/deerflow/persistence/thread_meta/model.py index 34a209277..fe15315e1 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/model.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/model.py @@ -15,7 +15,7 @@ class ThreadMetaRow(Base): thread_id: Mapped[str] = mapped_column(String(64), primary_key=True) assistant_id: Mapped[str | None] = mapped_column(String(128), index=True) - owner_id: Mapped[str | None] = mapped_column(String(64), index=True) + user_id: Mapped[str | None] = mapped_column(String(64), index=True) display_name: Mapped[str | None] = mapped_column(String(256)) status: Mapped[str] = mapped_column(String(20), default="idle") metadata_json: Mapped[dict] = mapped_column(JSON, default=dict) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 5a149e5d6..688fbb247 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow -from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id +from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id class ThreadMetaRepository(ThreadMetaStore): @@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore): thread_id: str, *, assistant_id: str | None = None, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, display_name: str | None = None, metadata: dict | None = None, ) -> dict: - # Auto-resolve owner_id from contextvar when AUTO; explicit None + # Auto-resolve user_id from contextvar when AUTO; explicit None # creates an orphan row (used by migration scripts). - resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create") + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create") now = datetime.now(UTC) row = ThreadMetaRow( thread_id=thread_id, assistant_id=assistant_id, - owner_id=resolved_owner_id, + user_id=resolved_user_id, display_name=display_name, metadata_json=metadata or {}, created_at=now, @@ -59,40 +59,34 @@ class ThreadMetaRepository(ThreadMetaStore): self, thread_id: str, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> dict | None: - resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get") + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get") async with self._sf() as session: row = await session.get(ThreadMetaRow, thread_id) if row is None: return None - # Enforce owner filter unless explicitly bypassed (owner_id=None). - if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + # Enforce owner filter unless explicitly bypassed (user_id=None). + if resolved_user_id is not None and row.user_id != resolved_user_id: return None return self._row_to_dict(row) - async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]: - stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool: - """Check if ``owner_id`` has access to ``thread_id``. + async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: + """Check if ``user_id`` has access to ``thread_id``. Two modes — one row, two distinct semantics depending on what the caller is about to do: - ``require_existing=False`` (default, permissive): Returns True for: row missing (untracked legacy thread), - ``row.owner_id`` is None (shared / pre-auth data), - or ``row.owner_id == owner_id``. Use for **read-style** + ``row.user_id`` is None (shared / pre-auth data), + or ``row.user_id == user_id``. Use for **read-style** decorators where treating an untracked thread as accessible preserves backward-compat. - ``require_existing=True`` (strict): Returns True **only** when the row exists AND - (``row.owner_id == owner_id`` OR ``row.owner_id is None``). + (``row.user_id == user_id`` OR ``row.user_id is None``). Use for **destructive / mutating** decorators (DELETE, PATCH, state-update) so a thread that has *already been deleted* cannot be re-targeted by any caller — closing the @@ -103,9 +97,9 @@ class ThreadMetaRepository(ThreadMetaStore): row = await session.get(ThreadMetaRow, thread_id) if row is None: return not require_existing - if row.owner_id is None: + if row.user_id is None: return True - return row.owner_id == owner_id + return row.user_id == user_id async def search( self, @@ -114,17 +108,17 @@ class ThreadMetaRepository(ThreadMetaStore): status: str | None = None, limit: int = 100, offset: int = 0, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> list[dict]: """Search threads with optional metadata and status filters. Owner filter is enforced by default: caller must be in a user - context. Pass ``owner_id=None`` to bypass (migration/CLI). + context. Pass ``user_id=None`` to bypass (migration/CLI). """ - resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search") + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search") stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) - if resolved_owner_id is not None: - stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id) if status: stmt = stmt.where(ThreadMetaRow.status == status) @@ -144,24 +138,24 @@ class ThreadMetaRepository(ThreadMetaStore): result = await session.execute(stmt) return [self._row_to_dict(r) for r in result.scalars()] - async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool: + async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool: """Return True if the row exists and is owned (or filter bypassed).""" - if resolved_owner_id is None: + if resolved_user_id is None: return True # explicit bypass row = await session.get(ThreadMetaRow, thread_id) - return row is not None and row.owner_id == resolved_owner_id + return row is not None and row.user_id == resolved_user_id async def update_display_name( self, thread_id: str, display_name: str, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> None: """Update the display_name (title) for a thread.""" - resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name") + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name") async with self._sf() as session: - if not await self._check_ownership(session, thread_id, resolved_owner_id): + if not await self._check_ownership(session, thread_id, resolved_user_id): return await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC))) await session.commit() @@ -171,11 +165,11 @@ class ThreadMetaRepository(ThreadMetaStore): thread_id: str, status: str, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> None: - resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status") + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status") async with self._sf() as session: - if not await self._check_ownership(session, thread_id, resolved_owner_id): + if not await self._check_ownership(session, thread_id, resolved_user_id): return await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC))) await session.commit() @@ -185,20 +179,20 @@ class ThreadMetaRepository(ThreadMetaStore): thread_id: str, metadata: dict, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> None: """Merge ``metadata`` into ``metadata_json``. Read-modify-write inside a single session/transaction so concurrent callers see consistent state. No-op if the row does not exist or - the owner_id check fails. + the user_id check fails. """ - resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata") + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata") async with self._sf() as session: row = await session.get(ThreadMetaRow, thread_id) if row is None: return - if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + if resolved_user_id is not None and row.user_id != resolved_user_id: return merged = dict(row.metadata_json or {}) merged.update(metadata) @@ -210,14 +204,14 @@ class ThreadMetaRepository(ThreadMetaStore): self, thread_id: str, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ) -> None: - resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete") + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete") async with self._sf() as session: row = await session.get(ThreadMetaRow, thread_id) if row is None: return - if resolved_owner_id is not None and row.owner_id != resolved_owner_id: + if resolved_user_id is not None and row.user_id != resolved_user_id: return await session.delete(row) await session.commit() diff --git a/backend/packages/harness/deerflow/runtime/__init__.py b/backend/packages/harness/deerflow/runtime/__init__.py index d5faa9018..5a3df2eb6 100644 --- a/backend/packages/harness/deerflow/runtime/__init__.py +++ b/backend/packages/harness/deerflow/runtime/__init__.py @@ -5,12 +5,18 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and directly from ``deerflow.runtime``. """ +from .checkpointer import checkpointer_context, get_checkpointer, make_checkpointer, reset_checkpointer from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple from .store import get_store, make_store, reset_store, store_context from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge __all__ = [ + # checkpointer + "checkpointer_context", + "get_checkpointer", + "make_checkpointer", + "reset_checkpointer", # runs "ConflictError", "DisconnectMode", diff --git a/backend/packages/harness/deerflow/agents/checkpointer/__init__.py b/backend/packages/harness/deerflow/runtime/checkpointer/__init__.py similarity index 100% rename from backend/packages/harness/deerflow/agents/checkpointer/__init__.py rename to backend/packages/harness/deerflow/runtime/checkpointer/__init__.py diff --git a/backend/packages/harness/deerflow/agents/checkpointer/async_provider.py b/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py similarity index 96% rename from backend/packages/harness/deerflow/agents/checkpointer/async_provider.py rename to backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py index 623333705..21c747b45 100644 --- a/backend/packages/harness/deerflow/agents/checkpointer/async_provider.py +++ b/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py @@ -7,12 +7,12 @@ Supported backends: memory, sqlite, postgres. Usage (e.g. FastAPI lifespan):: - from deerflow.agents.checkpointer.async_provider import make_checkpointer + from deerflow.runtime.checkpointer.async_provider import make_checkpointer async with make_checkpointer() as checkpointer: app.state.checkpointer = checkpointer # InMemorySaver if not configured -For sync usage see :mod:`deerflow.agents.checkpointer.provider`. +For sync usage see :mod:`deerflow.runtime.checkpointer.provider`. """ from __future__ import annotations @@ -24,12 +24,12 @@ from collections.abc import AsyncIterator from langgraph.types import Checkpointer -from deerflow.agents.checkpointer.provider import ( +from deerflow.config.app_config import get_app_config +from deerflow.runtime.checkpointer.provider import ( POSTGRES_CONN_REQUIRED, POSTGRES_INSTALL, SQLITE_INSTALL, ) -from deerflow.config.app_config import get_app_config from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str logger = logging.getLogger(__name__) diff --git a/backend/packages/harness/deerflow/agents/checkpointer/provider.py b/backend/packages/harness/deerflow/runtime/checkpointer/provider.py similarity index 98% rename from backend/packages/harness/deerflow/agents/checkpointer/provider.py rename to backend/packages/harness/deerflow/runtime/checkpointer/provider.py index 6f09aac94..59f8b1ab2 100644 --- a/backend/packages/harness/deerflow/agents/checkpointer/provider.py +++ b/backend/packages/harness/deerflow/runtime/checkpointer/provider.py @@ -7,7 +7,7 @@ Supported backends: memory, sqlite, postgres. Usage:: - from deerflow.agents.checkpointer.provider import get_checkpointer, checkpointer_context + from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context # Singleton — reused across calls, closed on process exit cp = get_checkpointer() diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 16252a26c..63328db43 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.models.run_event import RunEventRow from deerflow.runtime.events.store.base import RunEventStore -from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id +from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id logger = logging.getLogger(__name__) @@ -55,16 +55,22 @@ class DbRunEventStore(RunEventStore): return content, metadata or {} @staticmethod - def _owner_from_context() -> str | None: - """Soft read of owner_id from contextvar for write paths. + def _user_id_from_context() -> str | None: + """Soft read of user_id from contextvar for write paths. Returns ``None`` (no filter / no stamp) if contextvar is unset, which is the expected case for background worker writes. HTTP request writes will have the contextvar set by auth middleware and get their user_id stamped automatically. + + Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is + typed as ``UUID`` by the auth layer, but ``run_events.user_id`` + is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object + to a VARCHAR column ("type 'UUID' is not supported") — the + INSERT would silently roll back and the worker would hang. """ user = get_current_user() - return user.id if user is not None else None + return str(user.id) if user is not None else None async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401 """Write a single event — low-frequency path only. @@ -81,7 +87,7 @@ class DbRunEventStore(RunEventStore): metadata = {**(metadata or {}), "content_is_dict": True} else: db_content = content - owner_id = self._owner_from_context() + user_id = self._user_id_from_context() async with self._sf() as session: async with session.begin(): # Use FOR UPDATE to serialize seq assignment within a thread. @@ -92,7 +98,7 @@ class DbRunEventStore(RunEventStore): row = RunEventRow( thread_id=thread_id, run_id=run_id, - owner_id=owner_id, + user_id=user_id, event_type=event_type, category=category, content=db_content, @@ -106,7 +112,7 @@ class DbRunEventStore(RunEventStore): async def put_batch(self, events): if not events: return [] - owner_id = self._owner_from_context() + user_id = self._user_id_from_context() async with self._sf() as session: async with session.begin(): # Get max seq for the thread (assume all events in batch belong to same thread). @@ -130,7 +136,7 @@ class DbRunEventStore(RunEventStore): row = RunEventRow( thread_id=e["thread_id"], run_id=e["run_id"], - owner_id=e.get("owner_id", owner_id), + user_id=e.get("user_id", user_id), event_type=e["event_type"], category=category, content=db_content, @@ -149,12 +155,12 @@ class DbRunEventStore(RunEventStore): limit=50, before_seq=None, after_seq=None, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages") + resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages") stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") - if resolved_owner_id is not None: - stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(RunEventRow.user_id == resolved_user_id) if before_seq is not None: stmt = stmt.where(RunEventRow.seq < before_seq) if after_seq is not None: @@ -181,12 +187,12 @@ class DbRunEventStore(RunEventStore): *, event_types=None, limit=500, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events") + resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events") stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id) - if resolved_owner_id is not None: - stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(RunEventRow.user_id == resolved_user_id) if event_types: stmt = stmt.where(RunEventRow.event_type.in_(event_types)) stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) @@ -199,12 +205,12 @@ class DbRunEventStore(RunEventStore): thread_id, run_id, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run") + resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run") stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message") - if resolved_owner_id is not None: - stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(RunEventRow.user_id == resolved_user_id) stmt = stmt.order_by(RunEventRow.seq.asc()) async with self._sf() as session: result = await session.execute(stmt) @@ -214,12 +220,12 @@ class DbRunEventStore(RunEventStore): self, thread_id, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages") + resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages") stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") - if resolved_owner_id is not None: - stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + stmt = stmt.where(RunEventRow.user_id == resolved_user_id) async with self._sf() as session: return await session.scalar(stmt) or 0 @@ -227,13 +233,13 @@ class DbRunEventStore(RunEventStore): self, thread_id, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread") + resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread") async with self._sf() as session: count_conditions = [RunEventRow.thread_id == thread_id] - if resolved_owner_id is not None: - count_conditions.append(RunEventRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + count_conditions.append(RunEventRow.user_id == resolved_user_id) count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions) count = await session.scalar(count_stmt) or 0 if count > 0: @@ -246,13 +252,13 @@ class DbRunEventStore(RunEventStore): thread_id, run_id, *, - owner_id: str | None | _AutoSentinel = AUTO, + user_id: str | None | _AutoSentinel = AUTO, ): - resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run") + resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run") async with self._sf() as session: count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id] - if resolved_owner_id is not None: - count_conditions.append(RunEventRow.owner_id == resolved_owner_id) + if resolved_user_id is not None: + count_conditions.append(RunEventRow.user_id == resolved_user_id) count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions) count = await session.scalar(count_stmt) or 0 if count > 0: diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index b9aa019ad..a70404e11 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -50,6 +50,7 @@ class RunJournal(BaseCallbackHandler): # Write buffer self._buffer: list[dict] = [] + self._pending_flush_tasks: set[asyncio.Task[None]] = set() # Token accumulators self._total_input_tokens = 0 @@ -245,6 +246,19 @@ class RunJournal(BaseCallbackHandler): def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None: from langchain_core.messages import ToolMessage + from langgraph.types import Command + + # Tools that update graph state return a ``Command`` (e.g. + # ``present_files``). LangGraph later unwraps the inner ToolMessage + # into checkpoint state, so to stay checkpoint-aligned we must + # extract it here rather than storing ``str(Command(...))``. + if isinstance(output, Command): + update = getattr(output, "update", None) or {} + inner_msgs = update.get("messages") if isinstance(update, dict) else None + if isinstance(inner_msgs, list): + inner_tool_msg = next((m for m in inner_msgs if isinstance(m, ToolMessage)), None) + if inner_tool_msg is not None: + output = inner_tool_msg # Extract fields from ToolMessage object when LangChain provides one. # LangChain's _format_output wraps tool results into a ToolMessage @@ -381,6 +395,10 @@ class RunJournal(BaseCallbackHandler): """ if not self._buffer: return + # Skip if a flush is already in flight — avoids concurrent writes + # to the same SQLite file from multiple fire-and-forget tasks. + if self._pending_flush_tasks: + return try: loop = asyncio.get_running_loop() except RuntimeError: @@ -389,6 +407,7 @@ class RunJournal(BaseCallbackHandler): batch = self._buffer.copy() self._buffer.clear() task = loop.create_task(self._flush_async(batch)) + self._pending_flush_tasks.add(task) task.add_done_callback(self._on_flush_done) async def _flush_async(self, batch: list[dict]) -> None: @@ -404,8 +423,8 @@ class RunJournal(BaseCallbackHandler): # Return failed events to buffer for retry on next flush self._buffer = batch + self._buffer - @staticmethod - def _on_flush_done(task: asyncio.Task) -> None: + def _on_flush_done(self, task: asyncio.Task) -> None: + self._pending_flush_tasks.discard(task) if task.cancelled(): return exc = task.exception() @@ -450,10 +469,17 @@ class RunJournal(BaseCallbackHandler): async def flush(self) -> None: """Force flush remaining buffer. Called in worker's finally block.""" - if self._buffer: - batch = self._buffer.copy() - self._buffer.clear() - await self._store.put_batch(batch) + if self._pending_flush_tasks: + await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) + + while self._buffer: + batch = self._buffer[: self._flush_threshold] + del self._buffer[: self._flush_threshold] + try: + await self._store.put_batch(batch) + except Exception: + self._buffer = batch + self._buffer + raise def get_completion_data(self) -> dict: """Return accumulated token and message data for run completion.""" diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 9ba1caca3..3212e8ca3 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations: - MemoryRunStore: in-memory dict (development, tests) - Future: RunRepository backed by SQLAlchemy ORM -All methods accept an optional owner_id for user isolation. -When owner_id is None, no user filtering is applied (single-user mode). +All methods accept an optional user_id for user isolation. +When user_id is None, no user filtering is applied (single-user mode). """ from __future__ import annotations @@ -22,7 +22,7 @@ class RunStore(abc.ABC): *, thread_id: str, assistant_id: str | None = None, - owner_id: str | None = None, + user_id: str | None = None, status: str = "pending", multitask_strategy: str = "reject", metadata: dict[str, Any] | None = None, @@ -42,7 +42,7 @@ class RunStore(abc.ABC): self, thread_id: str, *, - owner_id: str | None = None, + user_id: str | None = None, limit: int = 100, ) -> list[dict[str, Any]]: pass diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 6f9d1dfb4..0b2b05f07 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -21,7 +21,7 @@ class MemoryRunStore(RunStore): *, thread_id, assistant_id=None, - owner_id=None, + user_id=None, status="pending", multitask_strategy="reject", metadata=None, @@ -35,7 +35,7 @@ class MemoryRunStore(RunStore): "run_id": run_id, "thread_id": thread_id, "assistant_id": assistant_id, - "owner_id": owner_id, + "user_id": user_id, "status": status, "multitask_strategy": multitask_strategy, "metadata": metadata or {}, @@ -49,8 +49,8 @@ class MemoryRunStore(RunStore): async def get(self, run_id): return self._runs.get(run_id) - async def list_by_thread(self, thread_id, *, owner_id=None, limit=100): - results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)] + async def list_by_thread(self, thread_id, *, user_id=None, limit=100): + results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)] results.sort(key=lambda r: r["created_at"], reverse=True) return results[:limit] diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index efa306b0b..74581a275 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -50,7 +50,7 @@ class RunContext: store: Any | None = field(default=None) event_store: Any | None = field(default=None) run_events_config: Any | None = field(default=None) - thread_meta_repo: Any | None = field(default=None) + thread_store: Any | None = field(default=None) follow_up_to_run_id: str | None = field(default=None) @@ -75,7 +75,7 @@ async def run_agent( store = ctx.store event_store = ctx.event_store run_events_config = ctx.run_events_config - thread_meta_repo = ctx.thread_meta_repo + thread_store = ctx.thread_store follow_up_to_run_id = ctx.follow_up_to_run_id run_id = record.run_id @@ -85,63 +85,7 @@ async def run_agent( pre_run_snapshot: dict[str, Any] | None = None snapshot_capture_failed = False - # Initialize RunJournal for event capture journal = None - if event_store is not None: - from deerflow.runtime.journal import RunJournal - - journal = RunJournal( - run_id=run_id, - thread_id=thread_id, - event_store=event_store, - track_token_usage=getattr(run_events_config, "track_token_usage", True), - ) - - # Write human_message event (model_dump format, aligned with checkpoint) - human_msg = _extract_human_message(graph_input) - if human_msg is not None: - msg_metadata = {} - if follow_up_to_run_id: - msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id - await event_store.put( - thread_id=thread_id, - run_id=run_id, - event_type="human_message", - category="message", - content=human_msg.model_dump(), - metadata=msg_metadata or None, - ) - content = human_msg.content - journal.set_first_human_message(content if isinstance(content, str) else str(content)) - - # Initialize RunJournal for event capture - journal = None - if event_store is not None: - from deerflow.runtime.journal import RunJournal - - journal = RunJournal( - run_id=run_id, - thread_id=thread_id, - event_store=event_store, - track_token_usage=getattr(run_events_config, "track_token_usage", True), - ) - - # Write human_message event (model_dump format, aligned with checkpoint) - human_msg = _extract_human_message(graph_input) - if human_msg is not None: - msg_metadata = {} - if follow_up_to_run_id: - msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id - await event_store.put( - thread_id=thread_id, - run_id=run_id, - event_type="human_message", - category="message", - content=human_msg.model_dump(), - metadata=msg_metadata or None, - ) - content = human_msg.content - journal.set_first_human_message(content if isinstance(content, str) else str(content)) # Track whether "events" was requested but skipped if "events" in requested_modes: @@ -151,6 +95,38 @@ async def run_agent( ) try: + # Initialize RunJournal + write human_message event. + # These are inside the try block so any exception (e.g. a DB + # error writing the event) flows through the except/finally + # path that publishes an "end" event to the SSE bridge — + # otherwise a failure here would leave the stream hanging + # with no terminator. + if event_store is not None: + from deerflow.runtime.journal import RunJournal + + journal = RunJournal( + run_id=run_id, + thread_id=thread_id, + event_store=event_store, + track_token_usage=getattr(run_events_config, "track_token_usage", True), + ) + + human_msg = _extract_human_message(graph_input) + if human_msg is not None: + msg_metadata = {} + if follow_up_to_run_id: + msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id + await event_store.put( + thread_id=thread_id, + run_id=run_id, + event_type="human_message", + category="message", + content=human_msg.model_dump(), + metadata=msg_metadata or None, + ) + content = human_msg.content + journal.set_first_human_message(content if isinstance(content, str) else str(content)) + # 1. Mark running await run_manager.set_status(run_id, RunStatus.running) @@ -334,12 +310,15 @@ async def run_agent( except Exception: logger.warning("Failed to flush journal for run %s", run_id, exc_info=True) - # Persist token usage + convenience fields to RunStore - completion = journal.get_completion_data() - await run_manager.update_run_completion(run_id, status=record.status.value, **completion) + try: + # Persist token usage + convenience fields to RunStore + completion = journal.get_completion_data() + await run_manager.update_run_completion(run_id, status=record.status.value, **completion) + except Exception: + logger.warning("Failed to persist run completion for %s (non-fatal)", run_id, exc_info=True) # Sync title from checkpoint to threads_meta.display_name - if checkpointer is not None: + if checkpointer is not None and thread_store is not None: try: ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} ckpt_tuple = await checkpointer.aget_tuple(ckpt_config) @@ -347,16 +326,17 @@ async def run_agent( ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {} title = ckpt.get("channel_values", {}).get("title") if title: - await thread_meta_repo.update_display_name(thread_id, title) + await thread_store.update_display_name(thread_id, title) except Exception: logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id) # Update threads_meta status based on run outcome - try: - final_status = "idle" if record.status == RunStatus.success else record.status.value - await thread_meta_repo.update_status(thread_id, final_status) - except Exception: - logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id) + if thread_store is not None: + try: + final_status = "idle" if record.status == RunStatus.success else record.status.value + await thread_store.update_status(thread_id, final_status) + except Exception: + logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id) await bridge.publish_end(run_id) asyncio.create_task(bridge.cleanup(run_id, delay=60)) diff --git a/backend/packages/harness/deerflow/runtime/store/async_provider.py b/backend/packages/harness/deerflow/runtime/store/async_provider.py index bc7a60559..68cd107c8 100644 --- a/backend/packages/harness/deerflow/runtime/store/async_provider.py +++ b/backend/packages/harness/deerflow/runtime/store/async_provider.py @@ -91,7 +91,7 @@ async def make_store() -> AsyncIterator[BaseStore]: configured checkpointer. Reads from the same ``checkpointer`` section of *config.yaml* used by - :func:`deerflow.agents.checkpointer.async_provider.make_checkpointer` so + :func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so that both singletons always use the same persistence technology:: async with make_store() as store: diff --git a/backend/packages/harness/deerflow/runtime/stream_bridge/async_provider.py b/backend/packages/harness/deerflow/runtime/stream_bridge/async_provider.py index 891f79fa0..f35b7d639 100644 --- a/backend/packages/harness/deerflow/runtime/stream_bridge/async_provider.py +++ b/backend/packages/harness/deerflow/runtime/stream_bridge/async_provider.py @@ -1,7 +1,7 @@ """Async stream bridge factory. Provides an **async context manager** aligned with -:func:`deerflow.agents.checkpointer.async_provider.make_checkpointer`. +:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer`. Usage (e.g. FastAPI lifespan):: diff --git a/backend/packages/harness/deerflow/runtime/user_context.py b/backend/packages/harness/deerflow/runtime/user_context.py index 07ffbb744..33fce65d5 100644 --- a/backend/packages/harness/deerflow/runtime/user_context.py +++ b/backend/packages/harness/deerflow/runtime/user_context.py @@ -1,11 +1,11 @@ -"""Request-scoped user context for owner-based authorization. +"""Request-scoped user context for user-based authorization. This module holds a :class:`~contextvars.ContextVar` that the gateway's auth middleware sets after a successful authentication. Repository methods read the contextvar via a sentinel default parameter, letting -routers stay free of ``owner_id`` boilerplate. +routers stay free of ``user_id`` boilerplate. -Three-state semantics for the repository ``owner_id`` parameter (the +Three-state semantics for the repository ``user_id`` parameter (the consumer side of this module lives in ``deerflow.persistence.*``): - ``_AUTO`` (module-private sentinel, default): read from contextvar; @@ -91,16 +91,16 @@ def require_current_user() -> CurrentUser: # --------------------------------------------------------------------------- -# Sentinel-based owner_id resolution +# Sentinel-based user_id resolution # --------------------------------------------------------------------------- # -# Repository methods accept an ``owner_id`` keyword-only argument that +# Repository methods accept a ``user_id`` keyword-only argument that # defaults to ``AUTO``. The three possible values drive distinct -# behaviours; see the docstring on :func:`resolve_owner_id`. +# behaviours; see the docstring on :func:`resolve_user_id`. class _AutoSentinel: - """Singleton marker meaning 'resolve owner_id from contextvar'.""" + """Singleton marker meaning 'resolve user_id from contextvar'.""" _instance: _AutoSentinel | None = None @@ -116,12 +116,12 @@ class _AutoSentinel: AUTO: Final[_AutoSentinel] = _AutoSentinel() -def resolve_owner_id( +def resolve_user_id( value: str | None | _AutoSentinel, *, method_name: str = "repository method", ) -> str | None: - """Resolve the owner_id parameter passed to a repository method. + """Resolve the user_id parameter passed to a repository method. Three-state semantics: @@ -131,16 +131,16 @@ def resolve_owner_id( - Explicit ``str``: use the provided id verbatim, overriding any contextvar value. Useful for tests and admin-override flows. - Explicit ``None``: no filter — the repository should skip the - owner_id WHERE clause entirely. Reserved for migration scripts + user_id WHERE clause entirely. Reserved for migration scripts and CLI tools that intentionally bypass isolation. """ if isinstance(value, _AutoSentinel): user = _current_user.get() if user is None: - raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.") + raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.") # Coerce to ``str`` at the boundary: ``User.id`` is typed as # ``UUID`` for the API surface, but the persistence layer - # stores ``owner_id`` as ``String(64)`` and aiosqlite cannot + # stores ``user_id`` as ``String(64)`` and aiosqlite cannot # bind a raw UUID object to a VARCHAR column ("type 'UUID' is # not supported"). Honour the documented return type here # rather than ripple a type change through every caller. diff --git a/backend/tests/_router_auth_helpers.py b/backend/tests/_router_auth_helpers.py index e48d01146..a7ce60468 100644 --- a/backend/tests/_router_auth_helpers.py +++ b/backend/tests/_router_auth_helpers.py @@ -3,16 +3,16 @@ The production gateway runs ``AuthMiddleware`` (validates the JWT cookie) ahead of every router, plus ``@require_permission(owner_check=True)`` decorators that read ``request.state.auth`` and call -``thread_meta_repo.check_access``. Router-level unit tests construct +``thread_store.check_access``. Router-level unit tests construct **bare** FastAPI apps that include only one router — they have neither -the auth middleware nor a real thread_meta_repo, so the decorators raise +the auth middleware nor a real thread_store, so the decorators raise 401 (TestClient path) or ValueError (direct-call path). This module provides two surfaces: 1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny ``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every - request, plus a permissive ``thread_meta_repo`` mock on + request, plus a permissive ``thread_store`` mock on ``app.state``. Use from TestClient-based router tests. 2. :func:`call_unwrapped` — invokes the underlying function bypassing @@ -86,20 +86,20 @@ def make_authed_test_app( user_factory: Callable[[], User] | None = None, owner_check_passes: bool = True, ) -> FastAPI: - """Build a FastAPI test app with stub auth + permissive thread_meta_repo. + """Build a FastAPI test app with stub auth + permissive thread_store. Args: user_factory: Override the default test user. Must return a fully populated :class:`User`. Useful for cross-user isolation tests that need a stable id across requests. - owner_check_passes: When True (default), ``thread_meta_repo.check_access`` + owner_check_passes: When True (default), ``thread_store.check_access`` returns True for every call so ``@require_permission(owner_check=True)`` never blocks the route under test. Pass False to verify that permission failures surface correctly. Returns: A ``FastAPI`` app with the stub middleware installed and - ``app.state.thread_meta_repo`` set to a permissive mock. The + ``app.state.thread_store`` set to a permissive mock. The caller is still responsible for ``app.include_router(...)``. """ factory = user_factory or _make_stub_user @@ -108,7 +108,7 @@ def make_authed_test_app( repo = MagicMock() repo.check_access = AsyncMock(return_value=owner_check_passes) - app.state.thread_meta_repo = repo + app.state.thread_store = repo return app diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 7fa0bf08e..d48630f37 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -60,7 +60,7 @@ def provisioner_module(): # Auto-set user context for every test unless marked no_auto_user # --------------------------------------------------------------------------- # -# Repository methods read ``owner_id`` from a contextvar by default +# Repository methods read ``user_id`` from a contextvar by default # (see ``deerflow.runtime.user_context``). Without this fixture, every # pre-existing persistence test would raise RuntimeError because the # contextvar is unset. The fixture sets a default test user on every diff --git a/backend/tests/test_checkpointer.py b/backend/tests/test_checkpointer.py index 79a4912d9..58f57237e 100644 --- a/backend/tests/test_checkpointer.py +++ b/backend/tests/test_checkpointer.py @@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest import deerflow.config.app_config as app_config_module -from deerflow.agents.checkpointer import get_checkpointer, reset_checkpointer from deerflow.config.checkpointer_config import ( CheckpointerConfig, get_checkpointer_config, load_checkpointer_config_from_dict, set_checkpointer_config, ) +from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer @pytest.fixture(autouse=True) @@ -78,7 +78,7 @@ class TestGetCheckpointer: """get_checkpointer should return InMemorySaver when not configured.""" from langgraph.checkpoint.memory import InMemorySaver - with patch("deerflow.agents.checkpointer.provider.get_app_config", side_effect=FileNotFoundError): + with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError): cp = get_checkpointer() assert cp is not None assert isinstance(cp, InMemorySaver) @@ -178,7 +178,7 @@ class TestAsyncCheckpointer: @pytest.mark.anyio async def test_sqlite_creates_parent_dir_via_to_thread(self): """Async SQLite setup should move mkdir off the event loop.""" - from deerflow.agents.checkpointer.async_provider import make_checkpointer + from deerflow.runtime.checkpointer.async_provider import make_checkpointer mock_config = MagicMock() mock_config.checkpointer = CheckpointerConfig(type="sqlite", connection_string="relative/test.db") @@ -195,11 +195,11 @@ class TestAsyncCheckpointer: mock_module.AsyncSqliteSaver = mock_saver_cls with ( - patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config), + patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config), patch.dict(sys.modules, {"langgraph.checkpoint.sqlite.aio": mock_module}), - patch("deerflow.agents.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread, + patch("deerflow.runtime.checkpointer.async_provider.asyncio.to_thread", new_callable=AsyncMock) as mock_to_thread, patch( - "deerflow.agents.checkpointer.async_provider.resolve_sqlite_conn_str", + "deerflow.runtime.checkpointer.async_provider.resolve_sqlite_conn_str", return_value="/tmp/resolved/test.db", ), ): diff --git a/backend/tests/test_checkpointer_none_fix.py b/backend/tests/test_checkpointer_none_fix.py index 1da435c85..3c7a25fa1 100644 --- a/backend/tests/test_checkpointer_none_fix.py +++ b/backend/tests/test_checkpointer_none_fix.py @@ -12,14 +12,14 @@ class TestCheckpointerNoneFix: @pytest.mark.anyio async def test_async_make_checkpointer_returns_in_memory_saver_when_not_configured(self): """make_checkpointer should return InMemorySaver when config.checkpointer is None.""" - from deerflow.agents.checkpointer.async_provider import make_checkpointer + from deerflow.runtime.checkpointer.async_provider import make_checkpointer # Mock get_app_config to return a config with checkpointer=None and database=None mock_config = MagicMock() mock_config.checkpointer = None mock_config.database = None - with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config): + with patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config): async with make_checkpointer() as checkpointer: # Should return InMemorySaver, not None assert checkpointer is not None @@ -36,13 +36,13 @@ class TestCheckpointerNoneFix: def test_sync_checkpointer_context_returns_in_memory_saver_when_not_configured(self): """checkpointer_context should return InMemorySaver when config.checkpointer is None.""" - from deerflow.agents.checkpointer.provider import checkpointer_context + from deerflow.runtime.checkpointer.provider import checkpointer_context # Mock get_app_config to return a config with checkpointer=None mock_config = MagicMock() mock_config.checkpointer = None - with patch("deerflow.agents.checkpointer.provider.get_app_config", return_value=mock_config): + with patch("deerflow.runtime.checkpointer.provider.get_app_config", return_value=mock_config): with checkpointer_context() as checkpointer: # Should return InMemorySaver, not None assert checkpointer is not None diff --git a/backend/tests/test_client.py b/backend/tests/test_client.py index a6d2ebfb3..a9b854e8e 100644 --- a/backend/tests/test_client.py +++ b/backend/tests/test_client.py @@ -817,7 +817,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares, patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt, patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()), + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), ): client._agent_name = "custom-agent" client._available_skills = {"test_skill"} @@ -842,7 +842,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.agents.checkpointer.get_checkpointer", return_value=mock_checkpointer), + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=mock_checkpointer), ): client._ensure_agent(config) @@ -867,7 +867,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()), + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config) @@ -886,7 +886,7 @@ class TestEnsureAgent: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.agents.checkpointer.get_checkpointer", return_value=None), + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=None), ): client._ensure_agent(config) @@ -1015,7 +1015,7 @@ class TestThreadQueries: mock_checkpointer = MagicMock() mock_checkpointer.list.return_value = [] - with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer): + with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer): # No internal checkpointer, should fetch from provider result = client.list_threads() @@ -1069,7 +1069,7 @@ class TestThreadQueries: mock_checkpointer = MagicMock() mock_checkpointer.list.return_value = [] - with patch("deerflow.agents.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer): + with patch("deerflow.runtime.checkpointer.provider.get_checkpointer", return_value=mock_checkpointer): result = client.get_thread("t99") assert result["thread_id"] == "t99" @@ -1844,7 +1844,7 @@ class TestScenarioAgentRecreation: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()), + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config_a) first_agent = client._agent @@ -1872,7 +1872,7 @@ class TestScenarioAgentRecreation: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()), + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config) client._ensure_agent(config) @@ -1897,7 +1897,7 @@ class TestScenarioAgentRecreation: patch("deerflow.client._build_middlewares", return_value=[]), patch("deerflow.client.apply_prompt_template", return_value="prompt"), patch.object(client, "_get_tools", return_value=[]), - patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()), + patch("deerflow.runtime.checkpointer.get_checkpointer", return_value=MagicMock()), ): client._ensure_agent(config) client.reset_agent() diff --git a/backend/tests/test_ensure_admin.py b/backend/tests/test_ensure_admin.py index 5079c4cfb..f787b2545 100644 --- a/backend/tests/test_ensure_admin.py +++ b/backend/tests/test_ensure_admin.py @@ -199,12 +199,12 @@ def test_migration_failure_is_non_fatal(): # ── Section 5.1-5.6 upgrade path: orphan thread migration ──────────────── -def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows(): +def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows(): """First boot finds Store-only legacy threads → stamps admin's id. Validates the **TC-UPG-02 upgrade story**: an operator running main (no auth) accumulates threads in the LangGraph Store namespace - ``("threads",)`` with no ``metadata.owner_id``. After upgrading to + ``("threads",)`` with no ``metadata.user_id``. After upgrading to feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should rewrite each unowned item with the freshly created admin's id. """ @@ -215,7 +215,7 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows(): SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}), SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}), SimpleNamespace(key="t3", value={"metadata": {}}), - SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}), + SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}), ] store = AsyncMock() # asearch returns the entire batch on first call, then an empty page @@ -235,11 +235,11 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows(): assert len(aput_calls) == 3 rewritten_keys = {call[1] for call in aput_calls} assert rewritten_keys == {"t1", "t2", "t3"} - # Each rewrite carries the new owner_id; titles preserved where present. + # Each rewrite carries the new user_id; titles preserved where present. by_key = {call[1]: call[2] for call in aput_calls} - assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42" + assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42" assert by_key["t1"]["metadata"]["title"] == "old-thread-1" - assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42" + assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42" # The pre-owned item must NOT have been rewritten. assert "t4" not in rewritten_keys diff --git a/backend/tests/test_feedback.py b/backend/tests/test_feedback.py index ed6c09f44..a592bdd22 100644 --- a/backend/tests/test_feedback.py +++ b/backend/tests/test_feedback.py @@ -60,8 +60,8 @@ class TestFeedbackRepository: @pytest.mark.anyio async def test_create_with_owner(self, tmp_path): repo = await _make_feedback_repo(tmp_path) - record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1") - assert record["owner_id"] == "user-1" + record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") + assert record["user_id"] == "user-1" await _cleanup() @pytest.mark.anyio @@ -97,10 +97,10 @@ class TestFeedbackRepository: @pytest.mark.anyio async def test_list_by_run(self, tmp_path): repo = await _make_feedback_repo(tmp_path) - await repo.create(run_id="r1", thread_id="t1", rating=1) - await repo.create(run_id="r1", thread_id="t1", rating=-1) - await repo.create(run_id="r2", thread_id="t1", rating=1) - results = await repo.list_by_run("t1", "r1") + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") + await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-2") + await repo.create(run_id="r2", thread_id="t1", rating=1, user_id="user-1") + results = await repo.list_by_run("t1", "r1", user_id=None) assert len(results) == 2 assert all(r["run_id"] == "r1" for r in results) await _cleanup() @@ -135,9 +135,9 @@ class TestFeedbackRepository: @pytest.mark.anyio async def test_aggregate_by_run(self, tmp_path): repo = await _make_feedback_repo(tmp_path) - await repo.create(run_id="r1", thread_id="t1", rating=1) - await repo.create(run_id="r1", thread_id="t1", rating=1) - await repo.create(run_id="r1", thread_id="t1", rating=-1) + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1") + await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-2") + await repo.create(run_id="r1", thread_id="t1", rating=-1, user_id="user-3") stats = await repo.aggregate_by_run("t1", "r1") assert stats["total"] == 3 assert stats["positive"] == 2 @@ -154,6 +154,80 @@ class TestFeedbackRepository: assert stats["negative"] == 0 await _cleanup() + @pytest.mark.anyio + async def test_upsert_creates_new(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + record = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + assert record["rating"] == 1 + assert record["feedback_id"] + assert record["user_id"] == "u1" + await _cleanup() + + @pytest.mark.anyio + async def test_upsert_updates_existing(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + first = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + second = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u1", comment="changed my mind") + assert second["feedback_id"] == first["feedback_id"] + assert second["rating"] == -1 + assert second["comment"] == "changed my mind" + await _cleanup() + + @pytest.mark.anyio + async def test_upsert_different_users_separate(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + r1 = await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + r2 = await repo.upsert(run_id="r1", thread_id="t1", rating=-1, user_id="u2") + assert r1["feedback_id"] != r2["feedback_id"] + assert r1["rating"] == 1 + assert r2["rating"] == -1 + await _cleanup() + + @pytest.mark.anyio + async def test_upsert_invalid_rating(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + with pytest.raises(ValueError): + await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1") + await _cleanup() + + @pytest.mark.anyio + async def test_delete_by_run(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") + assert deleted is True + results = await repo.list_by_run("t1", "r1", user_id="u1") + assert len(results) == 0 + await _cleanup() + + @pytest.mark.anyio + async def test_delete_by_run_nonexistent(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1") + assert deleted is False + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_thread_grouped(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1") + await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1") + await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1") + grouped = await repo.list_by_thread_grouped("t1", user_id="u1") + assert "r1" in grouped + assert "r2" in grouped + assert "r3" not in grouped + assert grouped["r1"]["rating"] == 1 + assert grouped["r2"]["rating"] == -1 + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_thread_grouped_empty(self, tmp_path): + repo = await _make_feedback_repo(tmp_path) + grouped = await repo.list_by_thread_grouped("t1", user_id="u1") + assert grouped == {} + await _cleanup() + # -- Follow-up association -- diff --git a/backend/tests/test_langgraph_auth.py b/backend/tests/test_langgraph_auth.py index 41fbd0340..52d215751 100644 --- a/backend/tests/test_langgraph_auth.py +++ b/backend/tests/test_langgraph_auth.py @@ -175,46 +175,46 @@ def _make_ctx(user_id): def test_filter_injects_user_id(): value = {} asyncio.run(add_owner_filter(_make_ctx("user-a"), value)) - assert value["metadata"]["owner_id"] == "user-a" + assert value["metadata"]["user_id"] == "user-a" def test_filter_preserves_existing_metadata(): value = {"metadata": {"title": "hello"}} asyncio.run(add_owner_filter(_make_ctx("user-a"), value)) - assert value["metadata"]["owner_id"] == "user-a" + assert value["metadata"]["user_id"] == "user-a" assert value["metadata"]["title"] == "hello" def test_filter_returns_user_id_dict(): result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {})) - assert result == {"owner_id": "user-x"} + assert result == {"user_id": "user-x"} def test_filter_read_write_consistency(): value = {} filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value)) - assert value["metadata"]["owner_id"] == filter_dict["owner_id"] + assert value["metadata"]["user_id"] == filter_dict["user_id"] def test_different_users_different_filters(): f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {})) f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {})) - assert f_a["owner_id"] != f_b["owner_id"] + assert f_a["user_id"] != f_b["user_id"] def test_filter_overrides_conflicting_user_id(): """If value already has a different user_id in metadata, it gets overwritten.""" - value = {"metadata": {"owner_id": "attacker"}} + value = {"metadata": {"user_id": "attacker"}} asyncio.run(add_owner_filter(_make_ctx("real-owner"), value)) - assert value["metadata"]["owner_id"] == "real-owner" + assert value["metadata"]["user_id"] == "real-owner" def test_filter_with_empty_metadata(): """Explicit empty metadata dict is fine.""" value = {"metadata": {}} result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value)) - assert value["metadata"]["owner_id"] == "user-z" - assert result == {"owner_id": "user-z"} + assert value["metadata"]["user_id"] == "user-z" + assert result == {"user_id": "user-z"} # ── Gateway parity ─────────────────────────────────────────────────────── diff --git a/backend/tests/test_memory_thread_meta_isolation.py b/backend/tests/test_memory_thread_meta_isolation.py new file mode 100644 index 000000000..25c9298f0 --- /dev/null +++ b/backend/tests/test_memory_thread_meta_isolation.py @@ -0,0 +1,156 @@ +"""Owner isolation tests for MemoryThreadMetaStore. + +Mirrors the SQL-backed tests in test_owner_isolation.py but exercises +the in-memory LangGraph Store backend used when database.backend=memory. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from langgraph.store.memory import InMemoryStore + +from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore +from deerflow.runtime.user_context import reset_current_user, set_current_user + +USER_A = SimpleNamespace(id="user-a", email="a@test.local") +USER_B = SimpleNamespace(id="user-b", email="b@test.local") + + +def _as_user(user): + class _Ctx: + def __enter__(self): + self._token = set_current_user(user) + return user + + def __exit__(self, *exc): + reset_current_user(self._token) + + return _Ctx() + + +@pytest.fixture +def store(): + return MemoryThreadMetaStore(InMemoryStore()) + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_search_isolation(store): + """search() returns only threads owned by the current user.""" + with _as_user(USER_A): + await store.create("t-alpha", display_name="A's thread") + with _as_user(USER_B): + await store.create("t-beta", display_name="B's thread") + + with _as_user(USER_A): + results = await store.search() + assert [r["thread_id"] for r in results] == ["t-alpha"] + + with _as_user(USER_B): + results = await store.search() + assert [r["thread_id"] for r in results] == ["t-beta"] + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_get_isolation(store): + """get() returns None for threads owned by another user.""" + with _as_user(USER_A): + await store.create("t-alpha", display_name="A's thread") + + with _as_user(USER_B): + assert await store.get("t-alpha") is None + + with _as_user(USER_A): + result = await store.get("t-alpha") + assert result is not None + assert result["display_name"] == "A's thread" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_update_display_name_denied(store): + """User B cannot rename User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha", display_name="original") + + with _as_user(USER_B): + await store.update_display_name("t-alpha", "hacked") + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + assert row["display_name"] == "original" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_update_status_denied(store): + """User B cannot change status of User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha") + + with _as_user(USER_B): + await store.update_status("t-alpha", "error") + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + assert row["status"] == "idle" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_update_metadata_denied(store): + """User B cannot modify metadata of User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha", metadata={"key": "original"}) + + with _as_user(USER_B): + await store.update_metadata("t-alpha", {"key": "hacked"}) + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + assert row["metadata"]["key"] == "original" + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_delete_denied(store): + """User B cannot delete User A's thread.""" + with _as_user(USER_A): + await store.create("t-alpha") + + with _as_user(USER_B): + await store.delete("t-alpha") + + with _as_user(USER_A): + row = await store.get("t-alpha") + assert row is not None + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_no_context_raises(store): + """Calling methods without user context raises RuntimeError.""" + with pytest.raises(RuntimeError, match="no user context is set"): + await store.search() + + +@pytest.mark.anyio +@pytest.mark.no_auto_user +async def test_explicit_none_bypasses_filter(store): + """user_id=None bypasses isolation (migration/CLI escape hatch).""" + with _as_user(USER_A): + await store.create("t-alpha") + with _as_user(USER_B): + await store.create("t-beta") + + all_rows = await store.search(user_id=None) + assert {r["thread_id"] for r in all_rows} == {"t-alpha", "t-beta"} + + row = await store.get("t-alpha", user_id=None) + assert row is not None diff --git a/backend/tests/test_owner_isolation.py b/backend/tests/test_owner_isolation.py index 4943936c7..33d21f3e3 100644 --- a/backend/tests/test_owner_isolation.py +++ b/backend/tests/test_owner_isolation.py @@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer owner filter directly by switching the ``user_context`` contextvar between two users. The safety property under test is: - After a repository write with owner_id=A, a subsequent read with - owner_id=B must not return the row, and vice versa. + After a repository write with user_id=A, a subsequent read with + user_id=B must not return the row, and vice versa. The HTTP layer is covered by test_auth_middleware.py, which proves that a request cookie reaches the ``set_current_user`` call. Together @@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path): await cleanup() -# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ── +# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ── @pytest.mark.anyio @pytest.mark.no_auto_user async def test_explicit_none_bypasses_filter(tmp_path): - """Migration scripts pass owner_id=None to see all rows regardless of owner.""" + """Migration scripts pass user_id=None to see all rows regardless of owner.""" from deerflow.persistence.engine import get_session_factory from deerflow.persistence.thread_meta import ThreadMetaRepository @@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path): await repo.create("t-beta") # Migration-style read: no contextvar, explicit None bypass. - all_rows = await repo.search(owner_id=None) + all_rows = await repo.search(user_id=None) thread_ids = {r["thread_id"] for r in all_rows} assert thread_ids == {"t-alpha", "t-beta"} # Explicit get with None does not apply the filter either. - row_a = await repo.get("t-alpha", owner_id=None) + row_a = await repo.get("t-alpha", user_id=None) assert row_a is not None - row_b = await repo.get("t-beta", owner_id=None) + row_b = await repo.get("t-beta", user_id=None) assert row_b is not None finally: await cleanup() diff --git a/backend/tests/test_persistence_scaffold.py b/backend/tests/test_persistence_scaffold.py index bd098c707..178a08e84 100644 --- a/backend/tests/test_persistence_scaffold.py +++ b/backend/tests/test_persistence_scaffold.py @@ -2,7 +2,7 @@ Tests: 1. DatabaseConfig property derivation (paths, URLs) -2. MemoryRunStore CRUD + owner_id filtering +2. MemoryRunStore CRUD + user_id filtering 3. Base.to_dict() via inspect mixin 4. Engine init/close lifecycle (memory + SQLite) 5. Postgres missing-dep error message @@ -24,18 +24,19 @@ class TestDatabaseConfig: assert c.backend == "memory" assert c.pool_size == 5 - def test_sqlite_paths_are_different(self): + def test_sqlite_paths_unified(self): c = DatabaseConfig(backend="sqlite", sqlite_dir="./mydata") - assert c.checkpointer_sqlite_path.endswith("checkpoints.db") - assert c.app_sqlite_path.endswith("app.db") - assert "mydata" in c.checkpointer_sqlite_path - assert c.checkpointer_sqlite_path != c.app_sqlite_path + assert c.sqlite_path.endswith("deerflow.db") + assert "mydata" in c.sqlite_path + # Backward-compatible aliases point to the same file + assert c.checkpointer_sqlite_path == c.sqlite_path + assert c.app_sqlite_path == c.sqlite_path def test_app_sqlalchemy_url_sqlite(self): c = DatabaseConfig(backend="sqlite", sqlite_dir="./data") url = c.app_sqlalchemy_url assert url.startswith("sqlite+aiosqlite:///") - assert "app.db" in url + assert "deerflow.db" in url def test_app_sqlalchemy_url_postgres(self): c = DatabaseConfig( @@ -105,17 +106,17 @@ class TestMemoryRunStore: @pytest.mark.anyio async def test_list_by_thread_owner_filter(self, store): - await store.put("r1", thread_id="t1", owner_id="alice") - await store.put("r2", thread_id="t1", owner_id="bob") - rows = await store.list_by_thread("t1", owner_id="alice") + await store.put("r1", thread_id="t1", user_id="alice") + await store.put("r2", thread_id="t1", user_id="bob") + rows = await store.list_by_thread("t1", user_id="alice") assert len(rows) == 1 - assert rows[0]["owner_id"] == "alice" + assert rows[0]["user_id"] == "alice" @pytest.mark.anyio async def test_owner_none_returns_all(self, store): - await store.put("r1", thread_id="t1", owner_id="alice") - await store.put("r2", thread_id="t1", owner_id="bob") - rows = await store.list_by_thread("t1", owner_id=None) + await store.put("r1", thread_id="t1", user_id="alice") + await store.put("r2", thread_id="t1", user_id="bob") + rows = await store.list_by_thread("t1", user_id=None) assert len(rows) == 2 @pytest.mark.anyio diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index dbb307a55..b306f59ec 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -709,6 +709,81 @@ class TestToolResultMessage: assert tool_end["metadata"]["tool_call_id"] == "call_from_obj" assert tool_end["metadata"]["tool_name"] == "web_search" + @pytest.mark.anyio + async def test_tool_invoke_end_to_end_unwraps_command(self, journal_setup): + """End-to-end: invoke a real LangChain tool that returns Command(update={'messages':[ToolMessage]}). + + This goes through the real LangChain callback path (tool.invoke -> CallbackManager + -> on_tool_start/on_tool_end), which is what the production agent uses. Mirrors + the ``present_files`` tool shape exactly. + """ + from langchain_core.callbacks import CallbackManager + from langchain_core.messages import ToolMessage + from langchain_core.tools import tool + from langgraph.types import Command + + j, store = journal_setup + + @tool + def fake_present_files(filepaths: list[str]) -> Command: + """Fake present_files that returns a Command with an inner ToolMessage.""" + return Command( + update={ + "artifacts": filepaths, + "messages": [ToolMessage("Successfully presented files", tool_call_id="tc_123")], + }, + ) + + # Real LangChain callback dispatch (matches production agent path) + cm = CallbackManager(handlers=[j]) + fake_present_files.invoke( + {"filepaths": ["/mnt/user-data/outputs/report.md"]}, + config={"callbacks": cm, "run_id": uuid4()}, + ) + await j.flush() + + messages = await store.list_messages("t1") + assert len(messages) == 1, f"expected 1 message event, got {len(messages)}: {messages}" + content = messages[0]["content"] + assert content["type"] == "tool" + # CRITICAL: must be the inner ToolMessage text, not str(Command(...)) + assert content["content"] == "Successfully presented files", ( + f"Command unwrap failed; stored content = {content['content']!r}" + ) + assert "Command(update=" not in str(content["content"]) + + @pytest.mark.anyio + async def test_tool_end_unwraps_command_with_inner_tool_message(self, journal_setup): + """Tools like ``present_files`` return Command(update={'messages': [ToolMessage(...)]}). + + LangGraph unwraps the inner ToolMessage into checkpoint state, so the + event store must do the same — otherwise it captures ``str(Command(...))`` + and the /history response diverges from the real rendered message. + """ + from langchain_core.messages import ToolMessage + from langgraph.types import Command + + j, store = journal_setup + run_id = uuid4() + inner = ToolMessage( + content="Successfully presented files", + tool_call_id="call_present", + name="present_files", + status="success", + ) + cmd = Command(update={"artifacts": ["/mnt/user-data/outputs/report.md"], "messages": [inner]}) + j.on_tool_end(cmd, run_id=run_id) + await j.flush() + + messages = await store.list_messages("t1") + assert len(messages) == 1 + content = messages[0]["content"] + assert content["type"] == "tool" + assert content["content"] == "Successfully presented files" + assert content["tool_call_id"] == "call_present" + assert content["name"] == "present_files" + assert "Command(update=" not in str(content["content"]) + @pytest.mark.anyio async def test_tool_message_object_overrides_kwargs(self, journal_setup): """ToolMessage object fields take priority over kwargs.""" diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 0a3ddc7dc..34ab9b492 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -73,11 +73,11 @@ class TestRunRepository: @pytest.mark.anyio async def test_list_by_thread_owner_filter(self, tmp_path): repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", owner_id="alice") - await repo.put("r2", thread_id="t1", owner_id="bob") - rows = await repo.list_by_thread("t1", owner_id="alice") + await repo.put("r1", thread_id="t1", user_id="alice") + await repo.put("r2", thread_id="t1", user_id="bob") + rows = await repo.list_by_thread("t1", user_id="alice") assert len(rows) == 1 - assert rows[0]["owner_id"] == "alice" + assert rows[0]["user_id"] == "alice" await _cleanup() @pytest.mark.anyio @@ -189,8 +189,8 @@ class TestRunRepository: @pytest.mark.anyio async def test_owner_none_returns_all(self, tmp_path): repo = await _make_repo(tmp_path) - await repo.put("r1", thread_id="t1", owner_id="alice") - await repo.put("r2", thread_id="t1", owner_id="bob") - rows = await repo.list_by_thread("t1", owner_id=None) + await repo.put("r1", thread_id="t1", user_id="alice") + await repo.put("r2", thread_id="t1", user_id="bob") + rows = await repo.list_by_thread("t1", user_id=None) assert len(rows) == 2 await _cleanup() diff --git a/backend/tests/test_suggestions_router.py b/backend/tests/test_suggestions_router.py index ea9eb41df..a8b9b0915 100644 --- a/backend/tests/test_suggestions_router.py +++ b/backend/tests/test_suggestions_router.py @@ -47,7 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch): monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) # Bypass the require_permission decorator (which needs request + - # thread_meta_repo) — these tests cover the parsing logic. + # thread_store) — these tests cover the parsing logic. result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == ["Q1", "Q2", "Q3"] @@ -67,7 +67,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch): monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) # Bypass the require_permission decorator (which needs request + - # thread_meta_repo) — these tests cover the parsing logic. + # thread_store) — these tests cover the parsing logic. result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == ["Q1", "Q2"] @@ -87,7 +87,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch): monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) # Bypass the require_permission decorator (which needs request + - # thread_meta_repo) — these tests cover the parsing logic. + # thread_store) — these tests cover the parsing logic. result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == ["Q1", "Q2"] @@ -104,7 +104,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch): monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) # Bypass the require_permission decorator (which needs request + - # thread_meta_repo) — these tests cover the parsing logic. + # thread_store) — these tests cover the parsing logic. result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) assert result.suggestions == [] diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index 442cf388a..3a6532567 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -43,8 +43,8 @@ class TestThreadMetaRepository: @pytest.mark.anyio async def test_create_with_owner_and_display_name(self, tmp_path): repo = await _make_repo(tmp_path) - record = await repo.create("t1", owner_id="user1", display_name="My Thread") - assert record["owner_id"] == "user1" + record = await repo.create("t1", user_id="user1", display_name="My Thread") + assert record["user_id"] == "user1" assert record["display_name"] == "My Thread" await _cleanup() @@ -61,26 +61,6 @@ class TestThreadMetaRepository: assert await repo.get("nonexistent") is None await _cleanup() - @pytest.mark.anyio - async def test_list_by_owner(self, tmp_path): - repo = await _make_repo(tmp_path) - await repo.create("t1", owner_id="user1") - await repo.create("t2", owner_id="user1") - await repo.create("t3", owner_id="user2") - results = await repo.list_by_owner("user1") - assert len(results) == 2 - assert all(r["owner_id"] == "user1" for r in results) - await _cleanup() - - @pytest.mark.anyio - async def test_list_by_owner_with_limit_and_offset(self, tmp_path): - repo = await _make_repo(tmp_path) - for i in range(5): - await repo.create(f"t{i}", owner_id="user1") - results = await repo.list_by_owner("user1", limit=2, offset=1) - assert len(results) == 2 - await _cleanup() - @pytest.mark.anyio async def test_check_access_no_record_allows(self, tmp_path): repo = await _make_repo(tmp_path) @@ -90,23 +70,23 @@ class TestThreadMetaRepository: @pytest.mark.anyio async def test_check_access_owner_matches(self, tmp_path): repo = await _make_repo(tmp_path) - await repo.create("t1", owner_id="user1") + await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1") is True await _cleanup() @pytest.mark.anyio async def test_check_access_owner_mismatch(self, tmp_path): repo = await _make_repo(tmp_path) - await repo.create("t1", owner_id="user1") + await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2") is False await _cleanup() @pytest.mark.anyio async def test_check_access_no_owner_allows_all(self, tmp_path): repo = await _make_repo(tmp_path) - # Explicit owner_id=None to bypass the new AUTO default that + # Explicit user_id=None to bypass the new AUTO default that # would otherwise pick up the test user from the autouse fixture. - await repo.create("t1", owner_id=None) + await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone") is True await _cleanup() @@ -125,27 +105,27 @@ class TestThreadMetaRepository: @pytest.mark.anyio async def test_check_access_strict_owner_match_allowed(self, tmp_path): repo = await _make_repo(tmp_path) - await repo.create("t1", owner_id="user1") + await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1", require_existing=True) is True await _cleanup() @pytest.mark.anyio async def test_check_access_strict_owner_mismatch_denied(self, tmp_path): repo = await _make_repo(tmp_path) - await repo.create("t1", owner_id="user1") + await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2", require_existing=True) is False await _cleanup() @pytest.mark.anyio async def test_check_access_strict_null_owner_still_allowed(self, tmp_path): - """Even in strict mode, a row with NULL owner_id stays shared. + """Even in strict mode, a row with NULL user_id stays shared. The strict flag tightens the *missing row* case, not the *shared row* case — legacy pre-auth rows that survived a clean migration without an owner are still everyone's. """ repo = await _make_repo(tmp_path) - await repo.create("t1", owner_id=None) + await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone", require_existing=True) is True await _cleanup() diff --git a/backend/tests/test_thread_state_event_store.py b/backend/tests/test_thread_state_event_store.py new file mode 100644 index 000000000..0d3b19761 --- /dev/null +++ b/backend/tests/test_thread_state_event_store.py @@ -0,0 +1,439 @@ +"""Tests for event-store-backed message loading in thread state/history endpoints. + +Covers the helper functions added to ``app/gateway/routers/threads.py``: + +- ``_sanitize_legacy_command_repr`` — extracts inner ToolMessage text from + legacy ``str(Command(...))`` strings captured before the ``journal.py`` + fix for state-updating tools like ``present_files``. +- ``_get_event_store_messages`` — loads the full message stream with full + pagination, copy-on-read id patching, legacy Command sanitization, and + a clean fallback to ``None`` when the event store is unavailable. +""" + +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import Any + +import pytest + +from app.gateway.routers.threads import ( + _get_event_store_messages, + _sanitize_legacy_command_repr, +) +from deerflow.runtime.events.store.memory import MemoryRunEventStore + + +@pytest.fixture() +def event_store() -> MemoryRunEventStore: + return MemoryRunEventStore() + + +class _FakeFeedbackRepo: + """Minimal ``FeedbackRepository`` stand-in that returns a configured map.""" + + def __init__(self, by_run: dict[str, dict] | None = None) -> None: + self._by_run = by_run or {} + + async def list_by_thread_grouped(self, thread_id: str, *, user_id: str | None) -> dict[str, dict]: + return dict(self._by_run) + + +def _make_request( + event_store: MemoryRunEventStore, + feedback_repo: _FakeFeedbackRepo | None = None, +) -> Any: + """Build a minimal FastAPI-like Request object. + + ``get_run_event_store(request)`` reads ``request.app.state.run_event_store``. + ``get_feedback_repo(request)`` reads ``request.app.state.feedback_repo``. + ``get_current_user`` is monkey-patched separately in tests that need it. + """ + state = SimpleNamespace( + run_event_store=event_store, + feedback_repo=feedback_repo or _FakeFeedbackRepo(), + ) + app = SimpleNamespace(state=state) + return SimpleNamespace(app=app) + + +@pytest.fixture(autouse=True) +def _stub_current_user(monkeypatch): + """Stub out ``get_current_user`` so tests don't need real auth context.""" + import app.gateway.routers.threads as threads_mod + + async def _fake(_request): + return None + + monkeypatch.setattr(threads_mod, "get_current_user", _fake) + + +async def _seed_simple_run(store: MemoryRunEventStore, thread_id: str, run_id: str) -> None: + """Seed one run: human + ai_tool_call + tool_result + final ai_message, plus a trace.""" + await store.put( + thread_id=thread_id, run_id=run_id, + event_type="human_message", category="message", + content={ + "type": "human", "id": None, + "content": [{"type": "text", "text": "hello"}], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + }, + ) + await store.put( + thread_id=thread_id, run_id=run_id, + event_type="ai_tool_call", category="message", + content={ + "type": "ai", "id": "lc_run--tc1", + "content": "", + "tool_calls": [{"name": "search", "args": {"q": "x"}, "id": "call_1", "type": "tool_call"}], + "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + "usage_metadata": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + }, + ) + await store.put( + thread_id=thread_id, run_id=run_id, + event_type="tool_result", category="message", + content={ + "type": "tool", "id": None, + "content": "results", + "tool_call_id": "call_1", "name": "search", + "artifact": None, "status": "success", + "additional_kwargs": {}, "response_metadata": {}, + }, + ) + await store.put( + thread_id=thread_id, run_id=run_id, + event_type="ai_message", category="message", + content={ + "type": "ai", "id": "lc_run--final1", + "content": "done", + "tool_calls": [], "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {"finish_reason": "stop"}, "name": None, + "usage_metadata": {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}, + }, + ) + # Non-message trace — must be filtered out. + await store.put( + thread_id=thread_id, run_id=run_id, + event_type="llm_request", category="trace", + content={"model": "test"}, + ) + + +class TestSanitizeLegacyCommandRepr: + def test_passthrough_non_string(self): + assert _sanitize_legacy_command_repr(None) is None + assert _sanitize_legacy_command_repr(42) == 42 + assert _sanitize_legacy_command_repr([{"type": "text", "text": "x"}]) == [{"type": "text", "text": "x"}] + + def test_passthrough_plain_string(self): + assert _sanitize_legacy_command_repr("Successfully presented files") == "Successfully presented files" + assert _sanitize_legacy_command_repr("") == "" + + def test_extracts_inner_content_single_quotes(self): + legacy = ( + "Command(update={'artifacts': ['/mnt/user-data/outputs/report.md'], " + "'messages': [ToolMessage(content='Successfully presented files', " + "tool_call_id='call_abc')]})" + ) + assert _sanitize_legacy_command_repr(legacy) == "Successfully presented files" + + def test_extracts_inner_content_double_quotes(self): + legacy = 'Command(update={"messages": [ToolMessage(content="ok", tool_call_id="x")]})' + assert _sanitize_legacy_command_repr(legacy) == "ok" + + def test_unparseable_command_returns_original(self): + legacy = "Command(update={'something_else': 1})" + assert _sanitize_legacy_command_repr(legacy) == legacy + + +class TestGetEventStoreMessages: + @pytest.mark.anyio + async def test_returns_none_when_store_empty(self, event_store): + request = _make_request(event_store) + assert await _get_event_store_messages(request, "t_missing") is None + + @pytest.mark.anyio + async def test_extracts_all_message_types_in_order(self, event_store): + await _seed_simple_run(event_store, "t1", "r1") + request = _make_request(event_store) + messages = await _get_event_store_messages(request, "t1") + assert messages is not None + types = [m["type"] for m in messages] + assert types == ["human", "ai", "tool", "ai"] + # Trace events must not appear + for m in messages: + assert m.get("type") in {"human", "ai", "tool"} + + @pytest.mark.anyio + async def test_null_ids_get_deterministic_uuid5(self, event_store): + await _seed_simple_run(event_store, "t1", "r1") + request = _make_request(event_store) + messages = await _get_event_store_messages(request, "t1") + assert messages is not None + + # AI messages keep their LLM ids + assert messages[1]["id"] == "lc_run--tc1" + assert messages[3]["id"] == "lc_run--final1" + + # Human (seq=1) + tool (seq=3) get deterministic uuid5 + expected_human_id = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:1")) + expected_tool_id = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:3")) + assert messages[0]["id"] == expected_human_id + assert messages[2]["id"] == expected_tool_id + + # Re-running produces the same ids (stability across requests) + messages2 = await _get_event_store_messages(request, "t1") + assert [m["id"] for m in messages2] == [m["id"] for m in messages] + + @pytest.mark.anyio + async def test_helper_does_not_mutate_store(self, event_store): + """Helper must copy content dicts; the live store must stay unchanged.""" + await _seed_simple_run(event_store, "t1", "r1") + request = _make_request(event_store) + _ = await _get_event_store_messages(request, "t1") + + # Raw store records still have id=None for human/tool + raw = await event_store.list_messages("t1", limit=500) + human = next(e for e in raw if e["content"]["type"] == "human") + tool = next(e for e in raw if e["content"]["type"] == "tool") + assert human["content"]["id"] is None + assert tool["content"]["id"] is None + + @pytest.mark.anyio + async def test_legacy_command_repr_sanitized(self, event_store): + """A tool_result whose content is a legacy ``str(Command(...))`` is cleaned.""" + legacy = ( + "Command(update={'artifacts': ['/mnt/user-data/outputs/x.md'], " + "'messages': [ToolMessage(content='Successfully presented files', " + "tool_call_id='call_p')]})" + ) + await event_store.put( + thread_id="t2", run_id="r1", + event_type="tool_result", category="message", + content={ + "type": "tool", "id": None, + "content": legacy, + "tool_call_id": "call_p", "name": "present_files", + "artifact": None, "status": "success", + "additional_kwargs": {}, "response_metadata": {}, + }, + ) + request = _make_request(event_store) + messages = await _get_event_store_messages(request, "t2") + assert messages is not None and len(messages) == 1 + assert messages[0]["content"] == "Successfully presented files" + + @pytest.mark.anyio + async def test_pagination_covers_more_than_one_page(self, event_store, monkeypatch): + """Simulate a long thread that exceeds a single page to exercise the loop.""" + thread_id = "t_long" + # Seed 12 human messages + for i in range(12): + await event_store.put( + thread_id=thread_id, run_id="r1", + event_type="human_message", category="message", + content={ + "type": "human", "id": None, + "content": [{"type": "text", "text": f"msg {i}"}], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + }, + ) + + # Force small page size to exercise pagination + import app.gateway.routers.threads as threads_mod + original = threads_mod._get_event_store_messages + + # Monkeypatch MemoryRunEventStore.list_messages to assert it's called with cursor pagination + calls: list[dict] = [] + real_list = event_store.list_messages + + async def spy_list_messages(tid, *, limit=50, before_seq=None, after_seq=None): + calls.append({"limit": limit, "after_seq": after_seq}) + return await real_list(tid, limit=limit, before_seq=before_seq, after_seq=after_seq) + + monkeypatch.setattr(event_store, "list_messages", spy_list_messages) + + request = _make_request(event_store) + messages = await original(request, thread_id) + assert messages is not None + assert len(messages) == 12 + assert [m["content"][0]["text"] for m in messages] == [f"msg {i}" for i in range(12)] + # At least one call was made with after_seq=None (the initial page) + assert any(c["after_seq"] is None for c in calls) + + @pytest.mark.anyio + async def test_summarize_regression_recovers_pre_summarize_messages(self, event_store): + """The exact bug: checkpoint would have only post-summarize messages; + event store must surface the original pre-summarize human query.""" + # Run 1 (pre-summarize) + await event_store.put( + thread_id="t_sum", run_id="r1", + event_type="human_message", category="message", + content={ + "type": "human", "id": None, + "content": [{"type": "text", "text": "original question"}], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + }, + ) + await event_store.put( + thread_id="t_sum", run_id="r1", + event_type="ai_message", category="message", + content={ + "type": "ai", "id": "lc_run--r1", + "content": "first answer", + "tool_calls": [], "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + "usage_metadata": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + }, + ) + # Run 2 (post-summarize — what the checkpoint still has) + await event_store.put( + thread_id="t_sum", run_id="r2", + event_type="human_message", category="message", + content={ + "type": "human", "id": None, + "content": [{"type": "text", "text": "follow up"}], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + }, + ) + await event_store.put( + thread_id="t_sum", run_id="r2", + event_type="ai_message", category="message", + content={ + "type": "ai", "id": "lc_run--r2", + "content": "second answer", + "tool_calls": [], "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + "usage_metadata": {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + }, + ) + + request = _make_request(event_store) + messages = await _get_event_store_messages(request, "t_sum") + assert messages is not None + # 4 messages, not 2 (which is what the summarized checkpoint would yield) + assert len(messages) == 4 + assert messages[0]["content"][0]["text"] == "original question" + assert messages[1]["id"] == "lc_run--r1" + assert messages[3]["id"] == "lc_run--r2" + + @pytest.mark.anyio + async def test_run_id_attached_to_every_message(self, event_store): + await _seed_simple_run(event_store, "t1", "r1") + request = _make_request(event_store) + messages = await _get_event_store_messages(request, "t1") + assert messages is not None + assert all(m.get("run_id") == "r1" for m in messages) + + @pytest.mark.anyio + async def test_feedback_attached_only_to_final_ai_message_per_run(self, event_store): + await _seed_simple_run(event_store, "t1", "r1") + feedback_repo = _FakeFeedbackRepo( + {"r1": {"feedback_id": "fb1", "rating": 1, "comment": "great"}} + ) + request = _make_request(event_store, feedback_repo=feedback_repo) + messages = await _get_event_store_messages(request, "t1") + assert messages is not None + + # human (0), ai_tool_call (1), tool (2), ai_message (3) + final_ai = messages[3] + assert final_ai["feedback"] == { + "feedback_id": "fb1", + "rating": 1, + "comment": "great", + } + # Non-final messages must NOT have a feedback key at all — the + # frontend keys button visibility off of this. + assert "feedback" not in messages[0] + assert "feedback" not in messages[1] + assert "feedback" not in messages[2] + + @pytest.mark.anyio + async def test_feedback_none_when_no_row_for_run(self, event_store): + await _seed_simple_run(event_store, "t1", "r1") + request = _make_request(event_store, feedback_repo=_FakeFeedbackRepo({})) + messages = await _get_event_store_messages(request, "t1") + assert messages is not None + # Final ai_message gets an explicit ``None`` — distinguishes "eligible + # but unrated" from "not eligible" (field absent). + assert messages[3]["feedback"] is None + + @pytest.mark.anyio + async def test_feedback_per_run_for_multi_run_thread(self, event_store): + """A thread with two runs: each final ai_message should get its own feedback.""" + # Run 1 + await event_store.put( + thread_id="t_multi", run_id="r1", + event_type="human_message", category="message", + content={"type": "human", "id": None, "content": "q1", + "additional_kwargs": {}, "response_metadata": {}, "name": None}, + ) + await event_store.put( + thread_id="t_multi", run_id="r1", + event_type="ai_message", category="message", + content={"type": "ai", "id": "lc_run--a1", "content": "a1", + "tool_calls": [], "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + "usage_metadata": None}, + ) + # Run 2 + await event_store.put( + thread_id="t_multi", run_id="r2", + event_type="human_message", category="message", + content={"type": "human", "id": None, "content": "q2", + "additional_kwargs": {}, "response_metadata": {}, "name": None}, + ) + await event_store.put( + thread_id="t_multi", run_id="r2", + event_type="ai_message", category="message", + content={"type": "ai", "id": "lc_run--a2", "content": "a2", + "tool_calls": [], "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + "usage_metadata": None}, + ) + feedback_repo = _FakeFeedbackRepo({ + "r1": {"feedback_id": "fb_r1", "rating": 1, "comment": None}, + "r2": {"feedback_id": "fb_r2", "rating": -1, "comment": "meh"}, + }) + request = _make_request(event_store, feedback_repo=feedback_repo) + messages = await _get_event_store_messages(request, "t_multi") + assert messages is not None + # human[r1], ai[r1], human[r2], ai[r2] + assert messages[1]["feedback"]["feedback_id"] == "fb_r1" + assert messages[1]["feedback"]["rating"] == 1 + assert messages[3]["feedback"]["feedback_id"] == "fb_r2" + assert messages[3]["feedback"]["rating"] == -1 + # Humans don't get feedback + assert "feedback" not in messages[0] + assert "feedback" not in messages[2] + + @pytest.mark.anyio + async def test_feedback_repo_failure_does_not_break_helper(self, monkeypatch, event_store): + """If feedback lookup throws, messages still come back without feedback.""" + await _seed_simple_run(event_store, "t1", "r1") + + class _BoomRepo: + async def list_by_thread_grouped(self, *a, **kw): + raise RuntimeError("db down") + + request = _make_request(event_store, feedback_repo=_BoomRepo()) + messages = await _get_event_store_messages(request, "t1") + assert messages is not None + assert len(messages) == 4 + for m in messages: + assert "feedback" not in m + + @pytest.mark.anyio + async def test_returns_none_when_dep_raises(self, monkeypatch, event_store): + """When ``get_run_event_store`` is not configured, helper returns None.""" + import app.gateway.routers.threads as threads_mod + + def boom(_request): + raise RuntimeError("no store") + + monkeypatch.setattr(threads_mod, "get_run_event_store", boom) + request = _make_request(event_store) + assert await threads_mod._get_event_store_messages(request, "t1") is None diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index 5864350a1..c6f063e32 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -113,14 +113,8 @@ def test_delete_thread_data_returns_generic_500_error(tmp_path): # ── Server-reserved metadata key stripping ────────────────────────────────── -def test_strip_reserved_metadata_removes_owner_id(): - """Client-supplied owner_id is dropped to prevent reflection attacks.""" - out = threads._strip_reserved_metadata({"owner_id": "victim-id", "title": "ok"}) - assert out == {"title": "ok"} - - def test_strip_reserved_metadata_removes_user_id(): - """user_id is also reserved (defense in depth for any future use).""" + """Client-supplied user_id is dropped to prevent reflection attacks.""" out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"}) assert out == {"title": "ok"} @@ -136,6 +130,6 @@ def test_strip_reserved_metadata_empty_input(): assert threads._strip_reserved_metadata({}) == {} -def test_strip_reserved_metadata_strips_both_simultaneously(): - out = threads._strip_reserved_metadata({"owner_id": "x", "user_id": "y", "keep": "me"}) +def test_strip_reserved_metadata_strips_all_reserved_keys(): + out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"}) assert out == {"keep": "me"} diff --git a/config.example.yaml b/config.example.yaml index aa78cc67c..07ef54bb3 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -740,8 +740,8 @@ skill_evolution: # backend: sqlite -- Single-node deployment, files in sqlite_dir # backend: postgres -- Production multi-node deployment # -# SQLite mode automatically uses separate .db files for checkpointer -# and application data to avoid write-lock contention. +# SQLite mode uses a single deerflow.db file with WAL journal mode +# for both checkpointer and application data. # # Postgres mode: put your connection URL in .env as DATABASE_URL, # then reference it here with $DATABASE_URL. diff --git a/docs/superpowers/plans/2026-04-10-event-store-history.md b/docs/superpowers/plans/2026-04-10-event-store-history.md new file mode 100644 index 000000000..0e3eb1c35 --- /dev/null +++ b/docs/superpowers/plans/2026-04-10-event-store-history.md @@ -0,0 +1,471 @@ +# Event Store History — Backend Compatibility Layer + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Replace checkpoint state with the append-only event store as the message source in the thread state/history endpoints, so summarization never causes message loss. + +**Architecture:** The Gateway's `get_thread_state` and `get_thread_history` endpoints currently read messages from `checkpoint.channel_values["messages"]`. After summarization, those messages are replaced with a synthetic summary-as-human message and all pre-summarize messages are gone. We modify these endpoints to read messages from the RunEventStore instead (append-only, unaffected by summarization). The response shape for each message stays identical so the chat render path needs no changes, but the frontend's feedback hook must be aligned to use the same full-history view (see Task 4). + +**Tech Stack:** Python (FastAPI, SQLAlchemy), pytest, TypeScript (React Query) + +**Scope:** Gateway mode only (`make dev-pro`). Standard mode uses the LangGraph Server directly and does not go through these endpoints; the summarize bug is still present there and must be tracked as a separate follow-up (see §"Follow-ups" at end of plan). + +**Prerequisite already landed:** `backend/packages/harness/deerflow/runtime/journal.py` now unwraps `Command(update={'messages':[ToolMessage(...)]})` in `on_tool_end`, so new runs that use state-updating tools (e.g. `present_files`) write the inner `ToolMessage` content to the event store instead of `str(Command(...))`. Legacy data captured before this fix is cleaned up defensively by the new helper (see Task 1 Step 3 `_sanitize_legacy_command_repr`). + +--- + +## Real Data Alignment Analysis + +Compared real `POST /history` response (checkpoint-based) with `run_events` table for thread `6d30913e-dcd4-41c8-8941-f66c716cf359` (docs/resp.json + backend/.deer-flow/data/deerflow.db). See `docs/superpowers/specs/2026-04-11-runjournal-history-evaluation.md` for full evidence chain. + +| Message type | Fields compared | Difference | +|-------------|----------------|------------| +| human_message | all fields | `id` is `None` in event store, has UUID in checkpoint | +| ai_message (tool_call) | all fields, 6 overlapping | **IDENTICAL** (0 diffs) | +| ai_message (final) | all fields | **IDENTICAL** | +| tool_result (normal) | all fields | Only `id` differs (`None` vs UUID) | +| tool_result (from `Command`-returning tool) | content | **Legacy data stored `str(Command(...))` repr instead of inner ToolMessage** — fixed in journal.py for new runs; legacy rows sanitized by helper | + +**Root cause for id difference:** LangGraph's checkpoint assigns `id` to HumanMessage and ToolMessage during graph execution. Event store writes happen earlier, when those ids are still None. AI messages receive `id` from the LLM response (`lc_run--*`) and are unaffected. + +**Fix for id:** Generate deterministic UUIDs for `id=None` messages using `uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")` at read time. Patch a **copy** of the content dict, never the live store object. + +**Summarize impact quantified on the reproducer thread**: event_store has 16 messages (7 AI + 9 others); checkpoint has 12 after summarize (5 AI + 7 others). AI id overlap: 5 of 7 — the 2 missing AI messages are pre-summarize. + +--- + +## File Structure + +| File | Action | Responsibility | +|------|--------|----------------| +| `backend/app/gateway/routers/threads.py` | Modify | Replace checkpoint messages with event store messages in `get_thread_state` and `get_thread_history` | +| `backend/tests/test_thread_state_event_store.py` | Create | Tests for the modified endpoints | + +--- + +### Task 1: Add `_get_event_store_messages` helper to `threads.py` + +A shared helper that loads the **full** message stream from the event store, patches `id=None` messages with deterministic UUIDs, and defensively sanitizes legacy `Command(update=...)` reprs captured before the journal.py fix. Patches a copy of each content dict so the live store is never mutated. + +**Design constraints (derived from evaluation §3, §4, §5):** +- **Full pagination**, not `limit=1000`. `RunEventStore.list_messages` returns "latest N records" — a fixed limit silently truncates older messages. Use `count_messages()` to size the request or loop with `after_seq` cursors. +- **Copy before mutate**. `MemoryRunEventStore` returns live dict references; the JSONL/DB stores may return detached rows but we must not rely on that. Always `content = dict(evt["content"])` before patching `id`. +- **Legacy Command sanitization.** Legacy data contains `content["content"] == "Command(update={'artifacts': [...], 'messages': [ToolMessage(content='X', ...)]})"`. Regex-extract the inner ToolMessage content string and replace; if extraction fails, leave content as-is (still strictly better than nothing because checkpoint fallback is also wrong for summarized threads). +- **User context.** `DbRunEventStore.list_messages` is user-scoped via `resolve_user_id(AUTO)` and relies on the auth contextvar set by `@require_permission`. Both endpoints are already decorated — document this dependency in the helper docstring. + +**Files:** +- Modify: `backend/app/gateway/routers/threads.py` +- Test: `backend/tests/test_thread_state_event_store.py` + +- [ ] **Step 1: Write the test** + +Create `backend/tests/test_thread_state_event_store.py`: + +```python +"""Tests for event-store-backed message loading in thread state/history endpoints.""" + +from __future__ import annotations + +import uuid + +import pytest + +from deerflow.runtime.events.store.memory import MemoryRunEventStore + + +@pytest.fixture() +def event_store(): + return MemoryRunEventStore() + + +async def _seed_conversation(event_store: MemoryRunEventStore, thread_id: str = "t1"): + """Seed a realistic multi-turn conversation matching real checkpoint format.""" + # human_message: id is None (same as real data) + await event_store.put( + thread_id=thread_id, run_id="r1", + event_type="human_message", category="message", + content={ + "type": "human", "id": None, + "content": [{"type": "text", "text": "Hello"}], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + }, + ) + # ai_tool_call: id is set by LLM + await event_store.put( + thread_id=thread_id, run_id="r1", + event_type="ai_tool_call", category="message", + content={ + "type": "ai", "id": "lc_run--abc123", + "content": "", + "tool_calls": [{"name": "search", "args": {"q": "cats"}, "id": "call_1", "type": "tool_call"}], + "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {}, "name": None, + "usage_metadata": {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + }, + ) + # tool_result: id is None (same as real data) + await event_store.put( + thread_id=thread_id, run_id="r1", + event_type="tool_result", category="message", + content={ + "type": "tool", "id": None, + "content": "Found 10 results", + "tool_call_id": "call_1", "name": "search", + "artifact": None, "status": "success", + "additional_kwargs": {}, "response_metadata": {}, + }, + ) + # ai_message: id is set by LLM + await event_store.put( + thread_id=thread_id, run_id="r1", + event_type="ai_message", category="message", + content={ + "type": "ai", "id": "lc_run--def456", + "content": "I found 10 results about cats.", + "tool_calls": [], "invalid_tool_calls": [], + "additional_kwargs": {}, "response_metadata": {"finish_reason": "stop"}, "name": None, + "usage_metadata": {"input_tokens": 200, "output_tokens": 100, "total_tokens": 300}, + }, + ) + # Also add a trace event — should NOT appear + await event_store.put( + thread_id=thread_id, run_id="r1", + event_type="llm_request", category="trace", + content={"model": "gpt-4"}, + ) + + +class TestGetEventStoreMessages: + """Verify event store message extraction with id patching.""" + + @pytest.mark.asyncio + async def test_extracts_all_message_types(self, event_store): + await _seed_conversation(event_store) + events = await event_store.list_messages("t1", limit=500) + messages = [evt["content"] for evt in events if isinstance(evt.get("content"), dict) and "type" in evt["content"]] + assert len(messages) == 4 + assert [m["type"] for m in messages] == ["human", "ai", "tool", "ai"] + + @pytest.mark.asyncio + async def test_null_ids_get_patched(self, event_store): + """Messages with id=None should get deterministic UUIDs.""" + await _seed_conversation(event_store) + events = await event_store.list_messages("t1", limit=500) + messages = [] + for evt in events: + content = evt.get("content") + if isinstance(content, dict) and "type" in content: + if content.get("id") is None: + content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"t1:{evt['seq']}")) + messages.append(content) + + # All messages now have an id + for m in messages: + assert m["id"] is not None + assert isinstance(m["id"], str) + assert len(m["id"]) > 0 + + # AI messages keep their original id + assert messages[1]["id"] == "lc_run--abc123" + assert messages[3]["id"] == "lc_run--def456" + + # Human and tool messages get deterministic ids (same input = same output) + human_id_1 = str(uuid.uuid5(uuid.NAMESPACE_URL, "t1:1")) + assert messages[0]["id"] == human_id_1 + + @pytest.mark.asyncio + async def test_empty_thread(self, event_store): + events = await event_store.list_messages("nonexistent", limit=500) + messages = [evt["content"] for evt in events if isinstance(evt.get("content"), dict)] + assert messages == [] + + @pytest.mark.asyncio + async def test_tool_call_fields_preserved(self, event_store): + await _seed_conversation(event_store) + events = await event_store.list_messages("t1", limit=500) + messages = [evt["content"] for evt in events if isinstance(evt.get("content"), dict) and "type" in evt["content"]] + + # AI tool_call message + ai_tc = messages[1] + assert ai_tc["tool_calls"][0]["name"] == "search" + assert ai_tc["tool_calls"][0]["id"] == "call_1" + + # Tool result + tool = messages[2] + assert tool["tool_call_id"] == "call_1" + assert tool["status"] == "success" +``` + +- [ ] **Step 2: Run tests to verify they pass** + +Run: `cd backend && PYTHONPATH=. uv run pytest tests/test_thread_state_event_store.py -v` + +- [ ] **Step 3: Add the helper function and modify `get_thread_history`** + +In `backend/app/gateway/routers/threads.py`: + +1. Add import at the top: +```python +import uuid # ADD (may already exist, check first) +from app.gateway.deps import get_run_event_store # ADD +``` + +2. Add the helper function (before the endpoint functions, after the model definitions): + +```python +_LEGACY_CMD_INNER_CONTENT_RE = re.compile( + r"ToolMessage\(content=(?P['\"])(?P.*?)(?P=q)", + re.DOTALL, +) + + +def _sanitize_legacy_command_repr(content_field: Any) -> Any: + """Recover the inner ToolMessage text from a legacy ``str(Command(...))`` repr. + + Runs that pre-date the ``on_tool_end`` fix in ``journal.py`` stored + ``str(Command(update={'messages':[ToolMessage(content='X', ...)]}))`` as the + tool_result content. New runs store ``'X'`` directly. For old threads, try + to extract ``'X'`` defensively; return the original string if extraction + fails (still no worse than the current checkpoint-based fallback, which is + broken for summarized threads anyway). + """ + if not isinstance(content_field, str) or not content_field.startswith("Command(update="): + return content_field + match = _LEGACY_CMD_INNER_CONTENT_RE.search(content_field) + return match.group("inner") if match else content_field + + +async def _get_event_store_messages(request: Request, thread_id: str) -> list[dict] | None: + """Load messages from the event store, returning None if unavailable. + + The event store is append-only and immune to summarization. Each + message event's ``content`` field contains a ``model_dump()``'d + LangChain Message dict that is already JSON-serialisable. + + **Full pagination, not a fixed limit.** ``RunEventStore.list_messages`` + returns the newest ``limit`` records when no cursor is given, which + silently drops older messages. We call ``count_messages()`` first and + request that many records. For stores that may return fewer (e.g. filtered + by user), we also fall back to ``after_seq``-cursor pagination. + + **Copy-on-read.** Each content dict is copied before ``id`` is patched so + the live store object is never mutated; ``MemoryRunEventStore`` returns + live references. + + **Legacy Command repr sanitization.** See ``_sanitize_legacy_command_repr``. + + **User context.** ``DbRunEventStore`` is user-scoped by default via + ``resolve_user_id(AUTO)`` (see ``runtime/user_context.py``). Callers of + this helper must be inside a request where ``@require_permission`` has + populated the user contextvar. Both ``get_thread_history`` and + ``get_thread_state`` satisfy that. Do not call this helper from CLI or + migration scripts without passing ``user_id=None`` explicitly. + + Returns ``None`` when the event store is not configured or contains no + messages for this thread, so callers can fall back to checkpoint messages. + """ + try: + event_store = get_run_event_store(request) + except Exception: + return None + + try: + total = await event_store.count_messages(thread_id) + except Exception: + logger.exception("count_messages failed for thread %s", sanitize_log_param(thread_id)) + return None + if not total: + return None + + # Batch by page_size to keep memory bounded for very long threads. + page_size = 500 + collected: list[dict] = [] + after_seq: int | None = None + while True: + page = await event_store.list_messages(thread_id, limit=page_size, after_seq=after_seq) + if not page: + break + collected.extend(page) + if len(page) < page_size: + break + after_seq = page[-1].get("seq") + if after_seq is None: + break + + messages: list[dict] = [] + for evt in collected: + raw = evt.get("content") + if not isinstance(raw, dict) or "type" not in raw: + continue + # Copy to avoid mutating the store-owned dict. + content = dict(raw) + if content.get("id") is None: + content["id"] = str(uuid.uuid5(uuid.NAMESPACE_URL, f"{thread_id}:{evt['seq']}")) + # Sanitize legacy Command reprs on tool_result messages only. + if content.get("type") == "tool": + content["content"] = _sanitize_legacy_command_repr(content.get("content")) + messages.append(content) + return messages if messages else None +``` + +Also add `import re` at the top of the file if it isn't already imported. + +3. In `get_thread_history` (around line 585-590), replace the messages section: + +**Before:** +```python + # Attach messages from checkpointer only for the latest checkpoint + if is_latest_checkpoint: + messages = channel_values.get("messages") + if messages: + values["messages"] = serialize_channel_values({"messages": messages}).get("messages", []) + is_latest_checkpoint = False +``` + +**After:** +```python + # Attach messages: prefer event store (immune to summarization), + # fall back to checkpoint messages when event store is unavailable. + if is_latest_checkpoint: + es_messages = await _get_event_store_messages(request, thread_id) + if es_messages is not None: + values["messages"] = es_messages + else: + messages = channel_values.get("messages") + if messages: + values["messages"] = serialize_channel_values({"messages": messages}).get("messages", []) + is_latest_checkpoint = False +``` + +- [ ] **Step 4: Modify `get_thread_state` similarly** + +In `get_thread_state` (around line 443-444), replace: + +**Before:** +```python + return ThreadStateResponse( + values=serialize_channel_values(channel_values), +``` + +**After:** +```python + values = serialize_channel_values(channel_values) + + # Override messages with event store data (immune to summarization) + es_messages = await _get_event_store_messages(request, thread_id) + if es_messages is not None: + values["messages"] = es_messages + + return ThreadStateResponse( + values=values, +``` + +- [ ] **Step 5: Run all backend tests** + +Run: `cd backend && PYTHONPATH=. uv run pytest tests/ -v --timeout=30 -x` + +- [ ] **Step 6: Commit** + +```bash +git add backend/app/gateway/routers/threads.py backend/tests/test_thread_state_event_store.py +git commit -m "feat(threads): load messages from event store instead of checkpoint state + +Event store is append-only and immune to summarization. Messages with +null ids (human, tool) get deterministic UUIDs based on thread_id:seq +for stable frontend rendering." +``` + +--- + +### Task 2 (OPTIONAL, deferred): Reduce flush_threshold for shorter mid-stream gap + +**Status:** Not a correctness fix. Re-evaluation (see spec) found that `RunJournal` already flushes on `run_end`, `run_error`, cancel, and worker `finally` paths. The only window this tuning narrows is a hard process crash or mid-run reload. Defer and decide separately; do not couple with Task 1 merge. + +If pursued: change `flush_threshold` default from 20 → 5 in `journal.py:42`, rerun `tests/test_run_journal.py`, commit as a separate `perf(journal): …` commit. + +--- + +### Task 3: Fix `useThreadFeedback` pagination in frontend + +Once `/history` returns the full event-store-backed message stream, the frontend's `runIdByAiIndex` map must also cover the full stream or its positional AI-index mapping drifts and feedback clicks go to the wrong `run_id`. The current hook hardcodes `limit=200`. + +**Files:** +- Modify: `frontend/src/core/threads/hooks.ts` (around line 679) + +- [ ] **Step 1: Replace the fixed `?limit=200` with full pagination** + +Change: + +```ts +const res = await fetchWithAuth( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/messages?limit=200`, +); +``` + +to a loop that pages via `after_seq` (or an equivalent query param exposed by the `/messages` endpoint — check `backend/app/gateway/routers/thread_runs.py:285-323` for the actual parameter names before writing the TS code). Accumulate `messages` until a page returns fewer than the page size. + +- [ ] **Step 2: Defensive index guard** + +`runIdByAiIndex[aiMessageIndex]` can still be `undefined` when the frontend renders optimistic state before the messages query refreshes. The current `?? undefined` in `message-list.tsx:71` already handles this; do not remove it. + +- [ ] **Step 3: Invalidate `["thread-feedback", threadId]` after a new run** + +In `useThreadStream` (or wherever stream-end is handled), call `queryClient.invalidateQueries({ queryKey: ["thread-feedback", threadId] })` when the stream closes so the runIdByAiIndex picks up the new run's AI message immediately. + +- [ ] **Step 4: Run `pnpm check`** + +```bash +cd frontend && pnpm check +``` + +- [ ] **Step 5: Commit** + +```bash +git add frontend/src/core/threads/hooks.ts +git commit -m "fix(feedback): paginate useThreadFeedback and invalidate after stream" +``` + +--- + +### Task 4: End-to-end test — summarize + multi-run feedback + +Add a regression test that exercises the exact bug class we are fixing: a summarized thread with at least two runs, where feedback clicks must target the correct `run_id`. + +**Files:** +- Modify: `backend/tests/test_thread_state_event_store.py` + +- [ ] **Step 1: Write the test** + +Seed a `MemoryRunEventStore` with two runs worth of messages (`r1`: human + ai + human + ai, `r2`: human + ai), then simulate a summarized checkpoint state that drops the `r1` messages. Call `_get_event_store_messages` and assert: + +- Length matches the event store, not the checkpoint +- The first message is the original `r1` human, not a summary +- AI messages preserve their `lc_run--*` ids in order +- Any `id=None` messages get a stable `uuid5(...)` id +- A legacy `str(Command(update=...))` content field in a tool_result is sanitized to the inner text + +- [ ] **Step 2: Run the new test** + +```bash +cd backend && PYTHONPATH=. uv run pytest tests/test_thread_state_event_store.py -v +``` + +- [ ] **Step 3: Commit with Tasks 1, 3 changes** + +Bundle with the Task 1 commit so tests always land alongside the implementation. + +--- + +### Task 5: Standard mode follow-up (documentation only) + +Standard mode (`make dev`) hits LangGraph Server directly for `/threads/{id}/history` and does not go through the Gateway router we just patched. The summarize bug is still present there. + +**Files:** +- Modify: this plan (add follow-up section at the bottom, see below) OR create a separate tracking issue + +- [ ] **Step 1: Record the gap** + +Append to the bottom of this plan (or open a GitHub issue and link it): + +> **Follow-up — Standard mode summarize bug** +> `get_thread_history` in `backend/app/gateway/routers/threads.py` is only hit in Gateway mode. Standard mode proxies `/api/langgraph/*` directly to the LangGraph Server (see `backend/CLAUDE.md` nginx routing and `frontend/CLAUDE.md` `NEXT_PUBLIC_LANGGRAPH_BASE_URL`). The summarize-message-loss symptom is still reproducible there. Options: (a) teach the LangGraph Server checkpointer to branch on an override, (b) move `/history` behind Gateway in Standard mode as well, (c) accept as known limitation for Standard mode. Decide before GA. diff --git a/docs/superpowers/specs/2026-04-11-runjournal-history-evaluation.md b/docs/superpowers/specs/2026-04-11-runjournal-history-evaluation.md new file mode 100644 index 000000000..44a466960 --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-runjournal-history-evaluation.md @@ -0,0 +1,191 @@ +# RunJournal 替换 History Messages — 方案评估与对比 + +**日期**:2026-04-11 +**分支**:`rayhpeng/fix-persistence-new` +**相关 plan**:[`docs/superpowers/plans/2026-04-10-event-store-history.md`](../plans/2026-04-10-event-store-history.md)(尚未落地) + +--- + +## 1. 问题与数据核对 + +**症状**:SummarizationMiddleware 触发后,前端历史中无法展示 summarize 之前的真实用户消息。 + +**复现数据**(thread `6d30913e-dcd4-41c8-8941-f66c716cf359`): + +| 数据源 | seq=1 的 message | 总 message 数 | 是否保留原始 human | +|---|---|---:|---| +| `run_events`(SQLite) | human `"最新伊美局势"` | 9(1 human + 7 ai_tool_call + 9 tool_result + 1 ai_message) | ✅ | +| `/history` 响应(`docs/resp.json`) | type=human,content=`"Here is a summary of the conversation to date:…"` | 不定 | ❌(已被 summary 替换)| + +**根因**:`backend/app/gateway/routers/threads.py:587-589` 的 `get_thread_history` 从 `checkpoint.channel_values["messages"]` 读取,而 LangGraph 的 SummarizationMiddleware 会原地改写这个列表。 + +--- + +## 2. 候选方案 + +| 方案 | 描述 | 本次是否推荐 | +|---|---|---| +| **A. event_store 覆盖 messages**(已有 plan) | `/history`、`/state` 改读 `RunEventStore.list_messages()`,覆盖 `channel_values["messages"]`;其它字段保持 checkpoint 来源 | ✅ 主方案 | +| B. 修 SummarizationMiddleware | 让 summarize 不原地替换 messages(作为附加 system message) | ❌ 违背 summarize 的 token 预算初衷 | +| C. 双读合并(checkpoint + event_store diff) | 合并 summarize 切点前后的两段 | ❌ 合并逻辑复杂无额外收益 | +| D. 切到现有 `/api/threads/{id}/messages` 端点 | 前端直接消费已经存在的 event-store 消息端点(`thread_runs.py:285-323`)| ⚠️ 更干净但需要前端改动 | + +--- + +## 3. Claude 自评 vs Codex 独立评估 + +两方独立分析了同一份 plan。重合点基本一致,但 **Codex 发现了一个我遗漏的关键 bug**。 + +### 3.1 一致结论 + +| 维度 | 结论 | +|---|---| +| 正确性方向 | event_store 是 append-only + 不受 summarize 影响,方向正确 | +| ID 补齐 | `uuid5(NAMESPACE_URL, f"{thread_id}:{seq}")` 稳定且确定性,安全 | +| 前端 schema | 零改动 | +| Non-message 字段(artifacts/todos/title/thread_data) | summarize 只影响 messages,不需要覆盖其它字段 | +| 多 checkpoint 语义 | 前端 `useStream` 只取 `limit: 1`(`frontend/src/core/threads/hooks.ts:203-210`),不做时间旅行;latest-only 可接受但应在注释/文档写清楚 | +| 作用域 | 仅 Gateway mode;Standard mode 直连 LangGraph Server,bug 在默认部署路径仍然存在 | + +### 3.2 Claude 的独立观察 + +1. 已验证数据对齐:plan 文档第 15-28 行的真实数据对齐表与本次 `run_events` 导出一致(9 条消息 id 分布:AI 来自 LLM `lc_run--*`、human/tool 为 None)。 +2. 担心 `run_end` / `run_error` / `cancel` 路径未必都 flush —— 这一点 Codex 实际核查了代码并给出确定结论(见下)。 +3. 方案 A 的单文件改动约 60 行,复杂度小。 + +### 3.3 Codex 的关键补充(Claude 遗漏) + +> **Bug #1 — Plan 用 `limit=1000` 并非全量** +> `RunEventStore.list_messages()` 的语义是"返回最新 limit 条"(`base.py:51-65`、`db.py:151-181`)。对于消息数超过 1000 的长对话,plan 当前写法会**丢掉最早的消息**,再次引入"消息丢失"bug(只是换了丢失的段)。 + +> **Bug #2 — helper 就地修改了 store 的 dict** +> plan 的 helper 里对 `content` 原地写 `id`;`MemoryRunEventStore` 返回的是**活引用**,会污染 store 中的对象。应 deep-copy 或 dict 推导出新对象。 + +> **Flush 路径已核查**: +> `RunJournal` 在 threshold (`journal.py:360-373`)、`run_end` (`91-96`)、`run_error` (`97-106`)、worker `finally` (`worker.py:280-286`) 都会 flush;`CancelledError` 也走 finally。**正常 end/error/cancel 都 flush,仅硬 kill / 进程崩溃会丢缓冲区**。 +> 因此 `flush_threshold 20 → 5` 的意义**仅在于硬崩溃窗口**与 mid-run reload 可见性,**不是正确性修复**,属于可选 tuning。代价是更多 put_batch / SQLite churn;且 `_flush_sync()` (`383-398`) 已防止并发 flush,所以"每 5 条一 flush"是 best-effort 非严格保证。 + +### 3.4 Codex 未否决但提示的次要点 + +- 方案 D(消费现有 `/api/threads/{id}/messages` 端点)更干净但需前端改动。 +- `/history` 一旦被方案 A 改过,就不再是严格意义上的"按 checkpoint 快照"API(对 messages 字段),应写进注释和 API 文档。 +- Standard mode 的 summarize bug 应建立独立 follow-up issue。 + +--- + +## 4. 最终合并判决 + +**Codex**:APPROVE-WITH-CHANGES +**Claude**:同意 Codex 的判决 + +### 合并前必须修改(Top 3) + +1. **修复分页 bug**:不能用固定 `limit=1000`。必须用以下之一: + - `count = await event_store.count_messages(thread_id)`,再 `list_messages(thread_id, limit=count)` + - 或循环 cursor 分页(`after_seq`)直到耗尽 +2. **不要原地修改 store dict**:helper 对 `content` 的 id 补齐需要 copy(`dict(content)` 浅拷贝足够,因为只写 top-level `id`) +3. **Standard mode 显式 follow-up**:在 plan 文末加 "Standard-mode follow-up: TODO #xxx",或在合并 PR 描述中明确这是 Gateway-only 止血 + +### 可选(非阻塞) + +4. `flush_threshold 20 → 5` 降级为"可选 tuning",不是修复的一部分;或独立一条 commit 并说明只对硬崩溃窗口有用 +5. `get_thread_history` 新增注释,说明 messages 字段脱离了 checkpoint 快照语义 +6. 测试覆盖:模拟 summarize 后的 checkpoint + 真实 event_store,端到端验证 `/history` 返回包含原始 human 消息 + +--- + +## 5. 推荐执行顺序 + +1. 按本文档 §4 修订 `docs/superpowers/plans/2026-04-10-event-store-history.md`(主要是 Task 1 的 helper 实现 + 分页) +2. 按修订后的 plan 执行(走 `superpowers:executing-plans`) +3. 合并后立即建 Standard mode follow-up issue + +## 6. Feedback 影响分析(2026-04-11 补充) + +### 6.1 数据模型 + +`feedback` 表(`persistence/feedback/model.py`): + +| 字段 | 说明 | +|---|---| +| `feedback_id` PK | - | +| `run_id` NOT NULL | 反馈目标 run | +| `thread_id` NOT NULL | - | +| `user_id` | - | +| `message_id` nullable | 注释明确写:`optional RunEventStore event identifier` — 已经面向 event_store 设计 | +| UNIQUE(thread_id, run_id, user_id) | 每 run 每用户至多一条 | + +**结论**:feedback **不按 message uuid 存**,按 `run_id` 存,所以 summarize 导致的 checkpoint messages 丢失**不会影响 feedback 存储**。schema 天生与 event_store 兼容,**无需数据迁移**。 + +### 6.2 前端的 runId 映射:发现隐藏 bug + +前端 feedback 目前走两条并行的数据链: + +| 用途 | 数据源 | 位置 | +|---|---|---| +| 渲染消息体 | `POST /history`(checkpoint) | `useStream` → `thread.messages` | +| 拿 `runId` 映射 | `GET /api/threads/{id}/messages?limit=200`(**event_store**) | `useThreadFeedback` (`hooks.ts:669-709`) | + +两者通过 **"AI 消息的序号"** 对齐: + +```ts +// hooks.ts:691-698 +for (const msg of messages) { + if (msg.event_type === "ai_message") { + runIdByAiIndex.push(msg.run_id); // 只按 AI 顺序 push + } +} +// message-list.tsx:70-71 +runId = feedbackData.runIdByAiIndex[aiMessageIndex] +``` + +**Bug**:summarize 过的 thread 里,两条数据链的 AI 消息数量和顺序**不一致**: + +| 数据源 | 本 thread 的 AI 消息序列 | 数量 | +|---|---|---:| +| `/history`(checkpoint,summarize 后) | seq=19,31,37,45,53 | 5 | +| `/messages`(event_store,完整) | seq=5,13,19,31,37,45,53 | 7 | + +结果:前端渲染的"第 0 条 AI 消息"是 seq=19,但 `runIdByAiIndex[0]` 指向 seq=5 的 run(本例同一 run 里没事,**跨多 run 的 thread 点赞就会打到错的 run 上**)。 + +**这个 bug 和本次 plan 无关,已经存在了**。只是用户未必注意到。 + +### 6.3 方案 A 对 feedback 的影响 + +**负面**:无。feedback 存储不受影响。 + +**正面(意外收益)**:`/history` 切换到 event_store 后,**两条数据链的 AI 消息序列自动对齐**,§6.2 的隐藏 bug 被顺带修好。 + +**前提条件**(加入 Top 3 改动之一同等重要): + +- 新 helper 必须和 `/messages` 端点用**同样的消息获取逻辑**(same store, same filter)。否则两条链仍然可能在边界条件下漂移 +- 具体说:**两边都要做完整分页**。目前 `/messages?limit=200` 在前端硬编码 200,如果 thread 有 >200 条消息就会截断;plan 的 `limit=1000` 也一样有上限。两个上限不一致 → 两边顺序不再对齐 → feedback 映射错位 +- **必须修**:`useThreadFeedback` 的 `limit=200` 需要改成分页获取全部,或者 `/messages` 后端改为默认全量 + +### 6.4 对前端改造顺序的影响 + +原 plan 声明"零前端改动",但加入 feedback 考虑后应修正为: + +| 改动 | 必须 | 可选 | +|---|---|---| +| 后端 `/history` 改读 event_store | ✅ | - | +| 后端 helper 用分页而非 `limit=1000` | ✅ | - | +| 前端 `useThreadFeedback` 改用分页或提升 limit | ✅ | - | +| `runIdByAiIndex` 增加防御:索引越界 fallback `undefined`(已有)| - | ✅ 已经是 | +| 前端改用 `/messages` 直接做渲染(方案 D) | - | ✅ 长期更干净 | + +### 6.5 feedback 相关的新 Top 3 补充 + +在原来的 Top 3 之外,再加: + +4. **前端 `useThreadFeedback` 必须分页或拉全**(`frontend/src/core/threads/hooks.ts:679`),否则和 `/history` 的新全量行为仍然错位 +5. **端到端测试**:一个 thread 跨 >1 个 run + 触发 summarize + 给历史 AI 消息点赞,确认 feedback 打到正确的 run_id +6. **TanStack Query 缓存协调**:`thread-feedback` 与 history 查询的 `staleTime` / invalidation 需要在新 run 结束时同步刷新,否则新消息写入后 `runIdByAiIndex` 没更新,点赞会打到上一个 run + +--- + +## 8. 未决问题 + +- `RunEventStore.count_messages()` 与 `list_messages(after_seq=...)` 的实际性能(SQLite 上对于数千消息级别应无问题,但未压测) +- `MemoryRunEventStore` 与 `DbRunEventStore` 分页语义是否一致(Codex 只核查了 `db.py`,`memory.py` 需确认) +- 是否应把 `/api/threads/{id}/messages` 提升为前端主用 endpoint,把 `/history` 保留为纯 checkpoint API —— 架构层面更干净但成本更高 diff --git a/docs/superpowers/specs/2026-04-11-summarize-marker-design.md b/docs/superpowers/specs/2026-04-11-summarize-marker-design.md new file mode 100644 index 000000000..79cd748d4 --- /dev/null +++ b/docs/superpowers/specs/2026-04-11-summarize-marker-design.md @@ -0,0 +1,203 @@ +# Summarize Marker in History — Design & Verification + +**Date**: 2026-04-11 +**Branch**: `rayhpeng/fix-persistence-new` +**Status**: Design approved, implementation deferred to a follow-up PR +**Depends on**: [`2026-04-11-runjournal-history-evaluation.md`](./2026-04-11-runjournal-history-evaluation.md) (the event-store-backed history fix this builds on) + +--- + +## 1. Goal + +Display a "summarization happened here" marker in the conversation history UI when `SummarizationMiddleware` ran mid-run, so users understand why earlier messages look condensed or missing. The event-store-backed `/history` fix already recovered the original messages; this spec adds a **visible marker** at the seq position where summarization occurred, optionally showing the generated summary text. + +## 2. Investigation findings + +### 2.1 Today's state: zero middleware records + +Full scan of `backend/.deer-flow/data/deerflow.db` `run_events`: + +| category | rows | +|---|---:| +| trace | 76 | +| message | 34 | +| lifecycle | 8 | +| **middleware** | **0** | + +No row has `event_type` containing `summariz` or `middleware`. The middleware category is dead in production. + +### 2.2 Why: two dead code paths in `journal.py` + +| Location | Status | +|---|---| +| `journal.py:343-362` — `on_custom_event("summarization", ...)` writes one trace event + one `category="middleware"` event. | Dead. Only fires when something calls `adispatch_custom_event("summarization", {...})`. The upstream LangChain `SummarizationMiddleware` (`.venv/.../langchain/agents/middleware/summarization.py:272`) **never emits custom events** — its `before_model`/`abefore_model` just mutate messages in place and return `{'messages': new_messages}`. Callback never triggered. | +| `journal.py:449` — `record_middleware(tag, *, name, hook, action, changes)` helper | Dead. Grep shows zero callers in the harness. Added speculatively, never wired up. | + +### 2.3 Concrete evidence of summarize running unlogged + +Thread `3d5dea4a-0983-4727-a4e8-41a64428933a`: + +- `run_events` seq=1 → original human `"写一份关于deer-flow的详细技术报告"` ✓ (event store is fine) +- `run_events` seq=43 → `llm_request` trace whose `messages[0]` literal contains `"Here is a summary of the conversation to date:"` — proof that SummarizationMiddleware did inject a summary mid-run +- Zero rows with `category='middleware'` for this thread → nothing captured for UI to render + +## 3. Approaches considered + +### A. Subclass `SummarizationMiddleware` and dispatch a custom event + +Wrap the upstream class, override `abefore_model`, call `await adispatch_custom_event("summarization", {...})` after super(). Journal's existing `on_custom_event` path captures it. + +### B. Frontend-only diff heuristic + +Compare `event_store.count_messages()` vs rendered count, infer summarization happened from the gap. **Rejected**: can't pinpoint position in the stream, can't show summary text. Only yields a vague badge. + +### C. Hybrid A + frontend inline card rendered at the middleware event's seq position + +Same backend as A, plus frontend renders an inline `[N messages condensed]` card at the correct chronological position. **Recommended terminal state**. + +## 4. Subagent's wrong claim and its rebuttal + +An independent agent flagged approach A as structurally broken because: + +> `RunnableCallable(trace=False)` skips `set_config_context`, therefore `var_child_runnable_config` is never set, therefore `adispatch_custom_event` raises `RuntimeError("Unable to dispatch an adhoc event without a parent run id")`. + +**This is wrong.** The user's counter-intuition was correct: `trace=False` does not prevent `adispatch_custom_event` from working, as long as the middleware signature explicitly accepts `config: RunnableConfig`. The mechanism: + +1. `RunnableCallable.__init__` (`langgraph/_internal/_runnable.py:293-319`) inspects the function signature. If it accepts `config: RunnableConfig`, that parameter is recorded in `self.func_accepts`. +2. Both `trace=True` and `trace=False` branches of `ainvoke` run the same kwarg-injection loop (`_runnable.py:349-356`): `if kw == "config": kw_value = config`. The `config` passed to `ainvoke` (from Pregel's `task.proc.ainvoke(task.input, config)` at `pregel/_retry.py:138`) is the task config with callbacks already bound. +3. Inside the middleware, passing that `config` explicitly to `adispatch_custom_event(..., config=config)` means the function doesn't rely on `var_child_runnable_config.get()` at all. The LangChain docstring at `langchain_core/callbacks/manager.py:2574-2579` even says "If using python 3.10 and async, you MUST specify the config parameter" — which is exactly this path. + +`trace=False` only changes whether **this runnable layer creates a new child callback scope**. It does not affect whether the outer-layer config (with callbacks including `RunJournal`) is passed down to the function. + +## 5. Verification + +Ran `/tmp/verify_summarize_event.py` (standalone minimal reproduction): + +- Minimal `AgentMiddleware` subclass with `abefore_model(self, state, runtime, config: RunnableConfig)` +- Calls `await adispatch_custom_event("summarization", {...}, config=config)` inside +- `create_agent(model=FakeChatModel, middleware=[probe])` +- `agent.ainvoke({...}, config={"callbacks": [RecordingHandler()]})` + +**Result**: + +``` +INFO verify: ProbeMiddleware.abefore_model called +INFO verify: config keys: ['callbacks', 'configurable', 'metadata'] +INFO verify: config.callbacks type: AsyncCallbackManager +INFO verify: config.metadata: {'langgraph_step': 1, 'langgraph_node': 'probe.before_model', ...} +INFO verify: on_custom_event fired: name=summarization + run_id=019d7d19-1727-7830-aa33-648ecbee4b95 + data={'summary': 'fake summary', 'replaced_count': 3} +SUCCESS: approach A is viable (config injection + adispatch work) +``` + +All five predictions held: + +1. ✅ `config: RunnableConfig` signature triggers auto-injection despite `trace=False` +2. ✅ `config.callbacks` is an `AsyncCallbackManager` with `parent_run_id` set +3. ✅ `adispatch_custom_event(..., config=config)` runs without error +4. ✅ `RecordingHandler.on_custom_event` receives the event +5. ✅ The received `run_id` is a valid UUID tied to the running graph + +**Bonus finding**: `config.metadata` contains `langgraph_step` and `langgraph_node`. These can be included in the middleware event's metadata to help the frontend position the marker on the timeline. + +## 6. Recommended implementation (approach C) + +### 6.1 Backend + +**New wrapper middleware** in `backend/packages/harness/deerflow/agents/lead_agent/agent.py`: + +```python +from langchain.agents.middleware.summarization import SummarizationMiddleware +from langchain_core.callbacks import adispatch_custom_event +from langchain_core.runnables import RunnableConfig + + +class _TrackingSummarizationMiddleware(SummarizationMiddleware): + """Wraps upstream SummarizationMiddleware to emit a ``summarization`` + custom event on every actual summarization, so RunJournal can persist + a middleware:summarize row to the event store. + + The upstream class does not emit events of its own. Declaring + ``config: RunnableConfig`` in the override lets LangGraph's + ``RunnableCallable`` inject the Pregel task config (with callbacks + and parent_run_id) regardless of ``trace=False`` on the node. + """ + + async def abefore_model(self, state, runtime, config: RunnableConfig): + before_count = len(state.get("messages") or []) + result = await super().abefore_model(state, runtime) + if result is None: + return None + + new_messages = result.get("messages") or [] + replaced_count = max(0, before_count - len(new_messages)) + summary_text = _extract_summary_text(new_messages) + + await adispatch_custom_event( + "summarization", + { + "summary": summary_text, + "replaced_count": replaced_count, + }, + config=config, + ) + return result + + +def _extract_summary_text(messages: list) -> str: + """Pull the summary string out of the HumanMessage the upstream class + injects as ``Here is a summary of the conversation to date:...``.""" + for msg in messages: + if getattr(msg, "type", None) == "human": + content = getattr(msg, "content", "") + text = content if isinstance(content, str) else "" + if text.startswith("Here is a summary of the conversation to date"): + return text + return "" +``` + +Swap the existing `SummarizationMiddleware()` instantiation in `_build_middlewares` for `_TrackingSummarizationMiddleware(...)` with the same args. + +**Journal change**: **zero**. `on_custom_event("summarization", ...)` in `journal.py:343-362` already writes both a trace and a `category="middleware"` row. + +**History helper change**: extend `_get_event_store_messages` in `backend/app/gateway/routers/threads.py` to surface `category="middleware"` rows as pseudo-messages, e.g.: + +```python +# In the per-event loop, after the existing message branch: +if evt.get("category") == "middleware" and evt.get("event_type") == "middleware:summarize": + meta = evt.get("metadata") or {} + messages.append({ + "id": f"summary-marker-{evt['seq']}", + "type": "summary_marker", + "replaced_count": meta.get("replaced_count", 0), + "summary": (raw or {}).get("content", "") if isinstance(raw, dict) else "", + "run_id": evt.get("run_id"), + }) +``` + +The marker uses a sentinel `type` (`summary_marker`) that doesn't collide with any LangChain message type, so downstream consumers that loop over messages can skip or render it explicitly. + +### 6.2 Frontend + +- `core/messages/utils.ts`: extend the message grouping to recognize `type === "summary_marker"` and yield it as its own group (`"assistant:summary-marker"`) +- `components/workspace/messages/message-list.tsx`: add a branch in the grouped render switch that renders a distinctive inline card showing `N messages condensed` and a collapsible panel with the summary text +- No changes to feedback logic: the marker has no `feedback` field so the button naturally doesn't render on it + +## 7. Risks + +1. **Synchronous path**. The upstream class has both `before_model` and `abefore_model`. Our wrapper only overrides the async variant. If any deer-flow code path ever uses the sync flow, those summarizations won't be captured. Mitigation: also override `before_model` and use `dispatch_custom_event` (sync variant) with the same pattern. +2. **`_extract_summary_text` fragility**. It depends on the upstream class prefix `"Here is a summary of the conversation to date"` in the injected `HumanMessage`. Any upstream template change breaks detection. Mitigation: pick the first new `HumanMessage` that wasn't in `state["messages"]` before super() — resilient to template wording changes at the cost of a small diff helper. +3. **`replaced_count` accuracy when concurrent updates**. If another middleware in the chain also modifies `state["messages"]` before super() returns, the naive `before_count - len(new_messages)` arithmetic is wrong. Mitigation: inspect the `RemoveMessage(id=REMOVE_ALL_MESSAGES)` that upstream emits and count from the original input list directly. +4. **History helper contract change**. Introducing a non-LangChain-typed entry (`type="summary_marker"`) in the `/history` response could break frontend code that blindly casts entries to `Message`. Mitigation: the frontend change above adds an explicit branch; type-check the frontend end-to-end before merging. + +## 8. Out of scope / deferred + +- Other middleware types (Title, Guardrail, HITL) do not emit custom events either. If we want markers for those too, repeat the wrapper pattern for each. Not in this design. +- Retroactive markers for old threads (captured before this patch) are impossible without re-running the graph. Legacy threads will show the event-store-recovered messages without a marker. +- Standard mode (`make dev`) — agent runs inside LangGraph Server, not the Gateway-embedded runtime. `RunJournal` may not be wired there, so the custom event fires but is captured by no one. Tracked as a separate follow-up. + +## 9. Next actions + +1. Land the current summarize-message-loss fixes (journal `Command` unwrap + event-store-backed `/history` + inline feedback) — implementation verified, being committed now as three commits on `rayhpeng/fix-persistence-new` +2. Summarize-marker implementation (this spec) → separate follow-up PR based on the above verified design diff --git a/frontend/src/components/workspace/messages/message-list-item.tsx b/frontend/src/components/workspace/messages/message-list-item.tsx index a5faf4cd1..faad2373e 100644 --- a/frontend/src/components/workspace/messages/message-list-item.tsx +++ b/frontend/src/components/workspace/messages/message-list-item.tsx @@ -1,6 +1,6 @@ import type { Message } from "@langchain/langgraph-sdk"; -import { FileIcon, Loader2Icon } from "lucide-react"; -import { memo, useMemo, type ImgHTMLAttributes } from "react"; +import { FileIcon, Loader2Icon, ThumbsDownIcon, ThumbsUpIcon } from "lucide-react"; +import { memo, useCallback, useMemo, useState, type ImgHTMLAttributes } from "react"; import rehypeKatex from "rehype-katex"; import { Loader } from "@/components/ai-elements/loader"; @@ -17,6 +17,11 @@ import { } from "@/components/ai-elements/reasoning"; import { Task, TaskTrigger } from "@/components/ai-elements/task"; import { Badge } from "@/components/ui/badge"; +import { + deleteFeedback, + upsertFeedback, + type FeedbackData, +} from "@/core/api/feedback"; import { resolveArtifactURL } from "@/core/artifacts/utils"; import { useI18n } from "@/core/i18n/hooks"; import { @@ -34,16 +39,85 @@ import { CopyButton } from "../copy-button"; import { MarkdownContent } from "./markdown-content"; +function FeedbackButtons({ + threadId, + runId, + initialFeedback, +}: { + threadId: string; + runId: string; + initialFeedback: FeedbackData | null; +}) { + const [feedback, setFeedback] = useState(initialFeedback); + const [isSubmitting, setIsSubmitting] = useState(false); + + const handleClick = useCallback( + async (rating: number) => { + if (isSubmitting) return; + setIsSubmitting(true); + try { + if (feedback?.rating === rating) { + await deleteFeedback(threadId, runId); + setFeedback(null); + } else { + const result = await upsertFeedback(threadId, runId, rating); + setFeedback(result); + } + } catch { + // Revert on error — feedback state unchanged on catch + } finally { + setIsSubmitting(false); + } + }, + [threadId, runId, feedback, isSubmitting], + ); + + return ( +
+ + +
+ ); +} + export function MessageListItem({ className, + threadId, message, isLoading, - threadId, + feedback, + runId, }: { className?: string; + threadId: string; message: Message; isLoading?: boolean; - threadId: string; + feedback?: FeedbackData | null; + runId?: string; }) { const isHuman = message.type === "human"; return ( @@ -61,7 +135,7 @@ export function MessageListItem({
@@ -72,6 +146,13 @@ export function MessageListItem({ "" } /> + {feedback !== undefined && runId && threadId && ( + + )}
)} diff --git a/frontend/src/components/workspace/messages/message-list.tsx b/frontend/src/components/workspace/messages/message-list.tsx index b7089bb72..b02ec2716 100644 --- a/frontend/src/components/workspace/messages/message-list.tsx +++ b/frontend/src/components/workspace/messages/message-list.tsx @@ -18,6 +18,7 @@ import { useRehypeSplitWordsIntoSpans } from "@/core/rehype"; import type { Subtask } from "@/core/tasks"; import { useUpdateSubtask } from "@/core/tasks/context"; import type { AgentThreadState } from "@/core/threads"; +import { useThreadMessageEnrichment } from "@/core/threads/hooks"; import { cn } from "@/lib/utils"; import { ArtifactFileList } from "../artifacts/artifact-file-list"; @@ -47,6 +48,8 @@ export function MessageList({ const rehypePlugins = useRehypeSplitWordsIntoSpans(thread.isLoading); const updateSubtask = useUpdateSubtask(); const messages = thread.messages; + const { data: enrichment } = useThreadMessageEnrichment(threadId); + if (thread.isThreadLoading && messages.length === 0) { return ; } @@ -58,12 +61,15 @@ export function MessageList({ {groupMessages(messages, (group) => { if (group.type === "human" || group.type === "assistant") { return group.messages.map((msg) => { + const entry = msg.id ? enrichment?.get(msg.id) : undefined; return ( ); }); @@ -167,7 +173,7 @@ export function MessageList({ results.push(
{t.subtasks.executing(tasks.size)}
, diff --git a/frontend/src/core/api/feedback.ts b/frontend/src/core/api/feedback.ts new file mode 100644 index 000000000..5af3f02c8 --- /dev/null +++ b/frontend/src/core/api/feedback.ts @@ -0,0 +1,42 @@ +import { getBackendBaseURL } from "../config"; + +import { fetchWithAuth } from "./fetcher"; + +export interface FeedbackData { + feedback_id: string; + rating: number; + comment: string | null; +} + +export async function upsertFeedback( + threadId: string, + runId: string, + rating: number, + comment?: string, +): Promise { + const res = await fetchWithAuth( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`, + { + method: "PUT", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ rating, comment: comment ?? null }), + }, + ); + if (!res.ok) { + throw new Error(`Failed to submit feedback: ${res.status}`); + } + return res.json(); +} + +export async function deleteFeedback( + threadId: string, + runId: string, +): Promise { + const res = await fetchWithAuth( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/runs/${encodeURIComponent(runId)}/feedback`, + { method: "DELETE" }, + ); + if (!res.ok && res.status !== 404) { + throw new Error(`Failed to delete feedback: ${res.status}`); + } +} diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 33424aeee..3b0aafda2 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -8,6 +8,7 @@ import { toast } from "sonner"; import type { PromptInputMessage } from "@/components/ai-elements/prompt-input"; import { getAPIClient } from "../api"; +import type { FeedbackData } from "../api/feedback"; import { fetchWithAuth } from "../api/fetcher"; import { getBackendBaseURL } from "../config"; import { useI18n } from "../i18n/hooks"; @@ -294,6 +295,9 @@ export function useThreadStream({ onFinish(state) { listeners.current.onFinish?.(state.values); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); + void queryClient.invalidateQueries({ + queryKey: ["thread-message-enrichment"], + }); }, }); @@ -678,3 +682,65 @@ export function useRenameThread() { }, }); } + +/** Per-message enrichment data attached by the backend ``/history`` helper. */ +export interface MessageEnrichment { + run_id: string; + /** ``undefined`` = not feedback-eligible; ``null`` = eligible but unrated. */ + feedback?: FeedbackData | null; +} + +/** + * Fetch ``/history`` once and index feedback + run_id by message id. + * + * Replaces the old ``useThreadFeedback`` hook which keyed by AI-message + * ordinal position — an inherently fragile mapping that broke whenever + * ``ai_tool_call`` messages were interleaved with ``ai_message`` messages. + * Keying by ``message.id`` is stable regardless of run count, tool-call + * chains, or summarization. + * + * The ``/history`` response is refreshed on every stream completion via + * ``invalidateQueries(["thread-message-enrichment"])`` in ``onFinish``. + */ +export function useThreadMessageEnrichment( + threadId: string | null | undefined, +) { + return useQuery({ + queryKey: ["thread-message-enrichment", threadId], + queryFn: async (): Promise> => { + const empty = new Map(); + if (!threadId) return empty; + const res = await fetchWithAuth( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadId)}/history`, + { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ limit: 1 }), + }, + ); + if (!res.ok) return empty; + const entries = (await res.json()) as Array<{ + values?: { + messages?: Array<{ + id?: string; + run_id?: string; + feedback?: FeedbackData | null; + }>; + }; + }>; + const messages = entries[0]?.values?.messages ?? []; + const map = new Map(); + for (const m of messages) { + if (!m.id || !m.run_id) continue; + const entry: MessageEnrichment = { run_id: m.run_id }; + // Preserve presence: "feedback" key absent → ineligible; present with + // null → eligible but unrated; present with object → rated. + if ("feedback" in m) entry.feedback = m.feedback; + map.set(m.id, entry); + } + return map; + }, + enabled: !!threadId, + staleTime: 30_000, + }); +}