diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 8cd907af1..ebfb0c987 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -31,27 +31,23 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: from deerflow.config import get_app_config from deerflow.persistence.engine import close_engine, init_engine_from_config from deerflow.runtime import make_store, make_stream_bridge - from deerflow.runtime.runs.store.memory import MemoryRunStore async with AsyncExitStack() as stack: app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge()) app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) app.state.store = await stack.enter_async_context(make_store()) - app.state.run_manager = RunManager() - # Initialize persistence layer from unified database config config = get_app_config() await init_engine_from_config(config.database) - # Initialize run store (MemoryRunStore for now; switch to ORM-backed - # RunRepository when models are implemented) - app.state.run_store = MemoryRunStore() + # Initialize run store (RunRepository if DB available, else MemoryRunStore) + app.state.run_store = _make_run_store() - # Initialize run event store (MemoryRunEventStore for now) - # TODO(Phase 2-B): switch to db/jsonl backend based on config.run_events.backend - from deerflow.runtime.events.store.memory import MemoryRunEventStore + # Initialize run event store based on config + app.state.run_event_store = _make_run_event_store(config) - app.state.run_event_store = MemoryRunEventStore() + # RunManager with store backing for persistence + app.state.run_manager = RunManager(store=app.state.run_store) try: yield @@ -59,6 +55,32 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: await close_engine() +# --------------------------------------------------------------------------- +# Factories +# 
def _make_run_store() -> RunStore:
    """Build the run store used by the gateway.

    Returns a ``RunRepository`` bound to the session factory when
    ``get_session_factory()`` yields one, otherwise a ``MemoryRunStore``.
    Backend modules are imported lazily, on the branch that needs them.
    """
    from deerflow.persistence.engine import get_session_factory

    session_factory = get_session_factory()
    if session_factory is None:
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        return MemoryRunStore()

    from deerflow.persistence.repositories.run_repo import RunRepository

    return RunRepository(session_factory)


def _make_run_event_store(config) -> RunEventStore:
    """Build the run event store from the ``run_events`` section of *config*.

    A config object without a ``run_events`` attribute results in ``None``
    being handed to ``make_run_event_store``.
    """
    from deerflow.runtime.events.store import make_run_event_store

    return make_run_event_store(getattr(config, "run_events", None))
@router.get("/{thread_id}/messages")
async def list_thread_messages(
    thread_id: str,
    request: Request,
    # ge=1: without a lower bound a zero/negative limit passes validation and
    # reaches the store, where it can defeat the intended 200-row cap.
    limit: int = Query(default=50, ge=1, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """Return displayable messages for a thread (across all runs).

    Args:
        thread_id: Thread whose messages are listed.
        request: Incoming request; used to look up the run event store.
        limit: Maximum number of messages to return (1-200).
        before_seq: Forwarded to the store as the ``before_seq`` cursor.
        after_seq: Forwarded to the store as the ``after_seq`` cursor.
    """
    event_store = get_run_event_store(request)
    return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)


@router.get("/{thread_id}/runs/{run_id}/messages")
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
    """Return displayable messages for a specific run."""
    event_store = get_run_event_store(request)
    return await event_store.list_messages_by_run(thread_id, run_id)


@router.get("/{thread_id}/runs/{run_id}/events")
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    limit: int = Query(default=500, ge=1, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit).

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose events are listed.
        request: Incoming request; used to look up the run event store.
        event_types: Optional comma-separated list of event types to keep.
            Surrounding whitespace is ignored and empty entries are dropped;
            if nothing remains, no type filter is applied.
        limit: Maximum number of events to return (1-2000).
    """
    event_store = get_run_event_store(request)
    types = None
    if event_types:
        # "a, b" and "a,,b" should both filter on exactly {"a", "b"}.
        types = [t.strip() for t in event_types.split(",") if t.strip()] or None
    return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)


def _token_count(run: dict, field: str) -> int:
    """Return ``run[field]``, treating a missing or ``None`` value as zero."""
    return run.get(field) or 0


@router.get("/{thread_id}/token-usage")
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Thread-level token usage aggregation.

    Only runs whose status is ``"success"`` or ``"error"`` are counted.
    NOTE(review): other finished statuses are excluded from the totals;
    confirm that is intended.

    Returns:
        A dict with the thread id, overall token totals, the number of
        counted runs, a per-model breakdown (``by_model``) and a per-caller
        breakdown (``by_caller``).
    """
    run_store = get_run_store(request)
    runs = await run_store.list_by_thread(thread_id, limit=10000)
    completed = [r for r in runs if r.get("status") in ("success", "error")]

    def total(field: str) -> int:
        return sum(_token_count(r, field) for r in completed)

    by_model: dict[str, dict] = {}
    for r in completed:
        model = r.get("model_name") or "unknown"
        entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
        entry["tokens"] += _token_count(r, "total_tokens")
        entry["runs"] += 1

    by_caller = {
        "lead_agent": total("lead_agent_tokens"),
        "subagent": total("subagent_tokens"),
        "middleware": total("middleware_tokens"),
    }

    return {
        "thread_id": thread_id,
        "total_tokens": total("total_tokens"),
        "total_input_tokens": total("total_input_tokens"),
        "total_output_tokens": total("total_output_tokens"),
        "total_runs": len(completed),
        "by_model": by_model,
        "by_caller": by_caller,
    }
class RunRow(Base):
    """ORM row for one run, stored in the ``runs`` table.

    Holds the run's identity and status, its JSON metadata/kwargs, a few
    denormalised convenience fields, and the token-usage counters.
    """

    __tablename__ = "runs"

    # Identity and ownership. thread_id and owner_id each get their own index.
    run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    assistant_id: Mapped[str | None] = mapped_column(String(128))
    owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
    status: Mapped[str] = mapped_column(String(20), default="pending")
    # "pending" | "running" | "success" | "error" | "timeout" | "interrupted"

    model_name: Mapped[str | None] = mapped_column(String(128))
    multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
    # JSON payloads; `default=dict` gives each new row its own empty dict.
    metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
    kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
    error: Mapped[str | None] = mapped_column(Text)

    # Convenience fields (for listing pages without querying RunEventStore)
    message_count: Mapped[int] = mapped_column(default=0)
    first_human_message: Mapped[str | None] = mapped_column(Text)
    last_ai_message: Mapped[str | None] = mapped_column(Text)

    # Token usage (accumulated in-memory by RunJournal, written on run completion)
    total_input_tokens: Mapped[int] = mapped_column(default=0)
    total_output_tokens: Mapped[int] = mapped_column(default=0)
    total_tokens: Mapped[int] = mapped_column(default=0)
    llm_call_count: Mapped[int] = mapped_column(default=0)
    lead_agent_tokens: Mapped[int] = mapped_column(default=0)
    subagent_tokens: Mapped[int] = mapped_column(default=0)
    middleware_tokens: Mapped[int] = mapped_column(default=0)

    # Follow-up association
    follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))

    # Timestamps default to timezone-aware UTC "now"; updated_at is also
    # refreshed through `onupdate`.
    # NOTE(review): the column type is left to be inferred from the annotation
    # (no explicit DateTime(timezone=True)) -- confirm the target database
    # accepts and round-trips tz-aware values as intended.
    created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
    updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))

    # Composite index over (thread_id, status).
    __table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
class RunEventRow(Base):
    """ORM row for a single run event, stored in the ``run_events`` table."""

    __tablename__ = "run_events"

    # Surrogate auto-increment primary key.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
    run_id: Mapped[str] = mapped_column(String(64), nullable=False)
    event_type: Mapped[str] = mapped_column(String(32), nullable=False)
    category: Mapped[str] = mapped_column(String(16), nullable=False)
    # "message" | "trace" | "lifecycle"
    content: Mapped[str] = mapped_column(Text, default="")
    # `default=dict` gives each new row its own empty dict.
    event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
    # Ordering number assigned by the writer.
    # NOTE(review): neither index below is declared unique, so the schema
    # itself does not prevent two rows of one thread from sharing a seq.
    seq: Mapped[int] = mapped_column(nullable=False)
    # Defaults to timezone-aware UTC "now".
    created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))

    __table_args__ = (
        # (thread_id, category, seq)
        Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
        # (thread_id, run_id, seq)
        Index("ix_events_run", "thread_id", "run_id", "seq"),
    )
class RunRepository(RunStore):
    """SQLAlchemy-backed ``RunStore``.

    Every method opens its own short-lived session from the session factory
    and closes it before returning, so no connection is held between calls.
    Rows are returned as plain dicts (see ``_row_to_dict``).
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    @staticmethod
    def _safe_json(obj: Any) -> Any:
        """Return a JSON-serializable version of *obj*.

        Containers are sanitised recursively. Objects exposing
        ``model_dump()`` or ``dict()`` are dumped and the dump is sanitised
        too, so nested non-JSON values (e.g. datetimes) cannot slip through.
        Anything else that ``json.dumps`` rejects is replaced by ``str(obj)``.
        """
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return obj
        if isinstance(obj, dict):
            return {
                # json only accepts str/int/float/bool/None keys; stringify the rest.
                (k if k is None or isinstance(k, (str, int, float, bool)) else str(k)): RunRepository._safe_json(v)
                for k, v in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [RunRepository._safe_json(v) for v in obj]
        for dumper_name in ("model_dump", "dict"):
            dumper = getattr(obj, dumper_name, None)
            if not callable(dumper):
                continue
            try:
                dumped = dumper()
            except Exception:
                # Deliberate best-effort: a failing dumper just means we try
                # the next strategy instead of failing the whole write.
                continue
            if isinstance(dumped, (dict, list, tuple)):
                return RunRepository._safe_json(dumped)
            return dumped
        try:
            json.dumps(obj)
            return obj
        except (TypeError, ValueError):
            return str(obj)

    @staticmethod
    def _to_datetime(value: str | datetime | None, fallback: datetime) -> datetime:
        """Coerce an ISO-8601 string or ``datetime`` to ``datetime``.

        Falsy values (``None``, ``""``) yield *fallback*.

        Raises:
            ValueError: If *value* is a string that is not valid ISO-8601.
        """
        if not value:
            return fallback
        if isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    @staticmethod
    def _row_to_dict(row: RunRow) -> dict[str, Any]:
        """Convert a ``RunRow`` to the dict shape used by the ``RunStore`` interface."""
        d = row.to_dict()
        # Remap JSON columns to match RunStore interface
        d["metadata"] = d.pop("metadata_json", {})
        d["kwargs"] = d.pop("kwargs_json", {})
        # Convert datetime to ISO string for consistency with MemoryRunStore
        for key in ("created_at", "updated_at"):
            val = d.get(key)
            if isinstance(val, datetime):
                d[key] = val.isoformat()
        return d

    async def put(
        self,
        run_id,
        *,
        thread_id,
        assistant_id=None,
        owner_id=None,
        status="pending",
        multitask_strategy="reject",
        metadata=None,
        kwargs=None,
        error=None,
        created_at=None,
    ):
        """Insert a new run row.

        ``metadata`` and ``kwargs`` are passed through ``_safe_json`` before
        being stored. ``created_at`` may be an ISO-8601 string or a
        ``datetime``; when omitted the current UTC time is used.
        """
        now = datetime.now(UTC)
        row = RunRow(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            owner_id=owner_id,
            status=status,
            multitask_strategy=multitask_strategy,
            metadata_json=self._safe_json(metadata) or {},
            kwargs_json=self._safe_json(kwargs) or {},
            error=error,
            created_at=self._to_datetime(created_at, now),
            updated_at=now,
        )
        async with self._sf() as session:
            session.add(row)
            await session.commit()

    async def get(self, run_id):
        """Return the run as a dict, or ``None`` if no row has this ``run_id``."""
        async with self._sf() as session:
            row = await session.get(RunRow, run_id)
            return self._row_to_dict(row) if row else None

    async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
        """Return up to *limit* runs of a thread, newest first.

        When *owner_id* is given, only runs with that owner are returned.
        """
        stmt = select(RunRow).where(RunRow.thread_id == thread_id)
        if owner_id is not None:
            stmt = stmt.where(RunRow.owner_id == owner_id)
        stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def update_status(self, run_id, status, *, error=None):
        """Set the run's status (and ``error`` when provided) and bump ``updated_at``."""
        values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
        if error is not None:
            values["error"] = error
        async with self._sf() as session:
            await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
            await session.commit()

    async def delete(self, run_id):
        """Delete the run row if it exists; a missing row is not an error."""
        async with self._sf() as session:
            row = await session.get(RunRow, run_id)
            if row is not None:
                await session.delete(row)
                await session.commit()

    async def list_pending(self, *, before=None):
        """Return pending runs created at or before *before*, oldest first.

        *before* may be an ISO-8601 string or a ``datetime``; it defaults to
        the current UTC time.
        """
        # created_at is a datetime column: compare against a datetime, not an
        # ISO string, so the driver receives a value of the right type.
        cutoff = self._to_datetime(before, datetime.now(UTC))
        stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= cutoff).order_by(RunRow.created_at.asc())
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        last_ai_message: str | None = None,
        first_human_message: str | None = None,
        error: str | None = None,
    ) -> None:
        """Update status + token usage + convenience fields on run completion.

        The counters are always overwritten with the given values. The two
        message previews and ``error`` are only written when not ``None``;
        previews are cut to their first 2000 characters.
        """
        values: dict[str, Any] = {
            "status": status,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "llm_call_count": llm_call_count,
            "lead_agent_tokens": lead_agent_tokens,
            "subagent_tokens": subagent_tokens,
            "middleware_tokens": middleware_tokens,
            "message_count": message_count,
            "updated_at": datetime.now(UTC),
        }
        if last_ai_message is not None:
            values["last_ai_message"] = last_ai_message[:2000]
        if first_human_message is not None:
            values["first_human_message"] = first_human_message[:2000]
        if error is not None:
            values["error"] = error
        async with self._sf() as session:
            await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
            await session.commit()
updated_at=now, + ) + async with self._sf() as session: + session.add(row) + await session.commit() + await session.refresh(row) + return self._row_to_dict(row) + + async def get(self, thread_id: str) -> dict | None: + async with self._sf() as session: + row = await session.get(ThreadMetaRow, thread_id) + return self._row_to_dict(row) if row else None + + async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]: + stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] + + async def check_access(self, thread_id: str, owner_id: str) -> bool: + """Check if owner_id has access to thread_id. + + Returns True if: row doesn't exist (untracked thread), owner_id + is None on the row (shared thread), or owner_id matches. + """ + async with self._sf() as session: + row = await session.get(ThreadMetaRow, thread_id) + if row is None: + return True + if row.owner_id is None: + return True + return row.owner_id == owner_id + + async def update_status(self, thread_id: str, status: str) -> None: + async with self._sf() as session: + await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC))) + await session.commit() + + async def delete(self, thread_id: str) -> None: + async with self._sf() as session: + row = await session.get(ThreadMetaRow, thread_id) + if row is not None: + await session.delete(row) + await session.commit() diff --git a/backend/packages/harness/deerflow/runtime/events/store/__init__.py b/backend/packages/harness/deerflow/runtime/events/store/__init__.py index 0da8fabe5..55f0dd33f 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/__init__.py +++ 
def make_run_event_store(config=None) -> RunEventStore:
    """Create a RunEventStore based on run_events.backend configuration.

    Args:
        config: Object exposing ``backend`` (``"memory"``, ``"db"`` or
            ``"jsonl"``) and, for the ``"db"`` backend, ``max_trace_content``.
            ``None`` selects the in-memory store.

    Returns:
        The store for the requested backend. If ``"db"`` is requested but no
        session factory is available, an in-memory store is returned and a
        warning is logged.

    Raises:
        ValueError: If ``config.backend`` is not a known backend name.
    """
    if config is None or config.backend == "memory":
        return MemoryRunEventStore()
    if config.backend == "db":
        from deerflow.persistence.engine import get_session_factory

        sf = get_session_factory()
        if sf is None:
            # database.backend=memory but run_events.backend=db -> fallback.
            # Say so loudly: the operator asked for persistence and is not
            # getting it.
            import logging

            logging.getLogger(__name__).warning(
                "run_events.backend is 'db' but no database session factory is available; "
                "falling back to the in-memory run event store (events will not be persisted)"
            )
            return MemoryRunEventStore()
        from deerflow.runtime.events.store.db import DbRunEventStore

        return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
    if config.backend == "jsonl":
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        return JsonlRunEventStore()
    raise ValueError(f"Unknown run_events backend: {config.backend!r}")
class DbRunEventStore(RunEventStore):
    """SQLAlchemy-backed ``RunEventStore`` writing to the ``run_events`` table.

    Each method opens its own session. ``seq`` numbers are assigned per
    thread as ``max(seq) + 1``. Content of ``"trace"`` events longer than
    ``max_trace_content`` characters is cut to that length and flagged with
    ``content_truncated`` in the event metadata.

    NOTE(review): the max-then-insert seq assignment is not atomic; two
    concurrent writers on one thread could compute the same seq. Confirm
    writers are serialised per thread, or add a uniqueness guarantee.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
        self._sf = session_factory
        self._max_trace_content = max_trace_content

    @staticmethod
    def _row_to_dict(row: RunEventRow) -> dict:
        """Convert a row to the event dict shape (no ``id``, ISO ``created_at``)."""
        d = row.to_dict()
        d["metadata"] = d.pop("event_metadata", {})
        val = d.get("created_at")
        if isinstance(val, datetime):
            d["created_at"] = val.isoformat()
        d.pop("id", None)
        return d

    def _truncate_trace(self, category: str, content: str, metadata: dict | None) -> tuple[str, dict]:
        """Cut over-long trace content; always return a metadata dict (never ``None``)."""
        if category == "trace" and len(content) > self._max_trace_content:
            content = content[: self._max_trace_content]
            metadata = {**(metadata or {}), "content_truncated": True}
        return content, metadata or {}

    @staticmethod
    async def _max_seq(session: AsyncSession, thread_id) -> int:
        """Return the highest seq stored for *thread_id*, or 0 if it has no events."""
        max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
        return max_seq or 0

    async def _delete_where(self, *conditions) -> int:
        """Delete rows matching *conditions*; return how many rows were counted beforehand."""
        async with self._sf() as session:
            count = await session.scalar(select(func.count()).select_from(RunEventRow).where(*conditions)) or 0
            if count > 0:
                await session.execute(delete(RunEventRow).where(*conditions))
                await session.commit()
            return count

    async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
        """Store one event with the thread's next seq and return it as a dict.

        ``created_at`` is an optional ISO-8601 string; the current UTC time
        is used when it is omitted.
        """
        content, metadata = self._truncate_trace(category, content, metadata)
        async with self._sf() as session:
            seq = await self._max_seq(session, thread_id) + 1
            row = RunEventRow(
                thread_id=thread_id,
                run_id=run_id,
                event_type=event_type,
                category=category,
                content=content,
                event_metadata=metadata,
                seq=seq,
                created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
            )
            session.add(row)
            await session.commit()
            await session.refresh(row)
            return self._row_to_dict(row)

    async def put_batch(self, events):
        """Store several events in one transaction and return them as dicts.

        Each event dict needs ``thread_id``, ``run_id`` and ``event_type``;
        ``category`` defaults to ``"trace"`` and ``content`` to ``""``.
        Seq numbers continue from each thread's own current maximum, so a
        batch may mix events of different threads.
        """
        if not events:
            return []
        async with self._sf() as session:
            # Look up every thread's starting point before adding any row, so
            # the numbering is driven purely by what is already stored.
            last_seq: dict = {}
            for thread_id in dict.fromkeys(e["thread_id"] for e in events):
                last_seq[thread_id] = await self._max_seq(session, thread_id)

            rows = []
            for e in events:
                thread_id = e["thread_id"]
                last_seq[thread_id] += 1
                category = e.get("category", "trace")
                content, metadata = self._truncate_trace(category, e.get("content", ""), e.get("metadata"))
                row = RunEventRow(
                    thread_id=thread_id,
                    run_id=e["run_id"],
                    event_type=e["event_type"],
                    category=category,
                    content=content,
                    event_metadata=metadata,
                    seq=last_seq[thread_id],
                    created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
                )
                session.add(row)
                rows.append(row)
            await session.commit()
            for row in rows:
                await session.refresh(row)
            return [self._row_to_dict(r) for r in rows]

    async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
        """Return ``"message"`` events of a thread in ascending seq order.

        With ``after_seq`` the first *limit* messages after that cursor are
        returned; otherwise the last *limit* messages (optionally only those
        before ``before_seq``).
        """
        stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
        if before_seq is not None:
            stmt = stmt.where(RunEventRow.seq < before_seq)
        if after_seq is not None:
            stmt = stmt.where(RunEventRow.seq > after_seq)

        # Forward pagination reads ascending; otherwise read the tail
        # descending and flip it so callers always get ascending order.
        forward = after_seq is not None
        ordering = RunEventRow.seq.asc() if forward else RunEventRow.seq.desc()
        async with self._sf() as session:
            result = await session.execute(stmt.order_by(ordering).limit(limit))
            rows = list(result.scalars())
        if not forward:
            rows.reverse()
        return [self._row_to_dict(r) for r in rows]

    async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
        """Return up to *limit* events of a run in ascending seq order.

        A non-empty ``event_types`` restricts the result to those types.
        """
        stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
        if event_types:
            stmt = stmt.where(RunEventRow.event_type.in_(event_types))
        stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def list_messages_by_run(self, thread_id, run_id):
        """Return all ``"message"`` events of a run in ascending seq order."""
        stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def count_messages(self, thread_id):
        """Return the number of ``"message"`` events stored for a thread."""
        stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
        async with self._sf() as session:
            return await session.scalar(stmt) or 0

    async def delete_by_thread(self, thread_id):
        """Delete every event of a thread; return the number of events found."""
        return await self._delete_where(RunEventRow.thread_id == thread_id)

    async def delete_by_run(self, thread_id, run_id):
        """Delete every event of one run; return the number of events found."""
        return await self._delete_where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
class JsonlRunEventStore(RunEventStore):
    """RunEventStore that appends each run's events to its own JSONL file.

    Files live at ``{base_dir}/threads/{thread_id}/runs/{run_id}.jsonl`` and
    hold every category (message, trace, lifecycle) for that run.  ``seq`` is
    assigned per thread from a counter kept on this instance and seeded
    lazily from the files already on disk.

    All file I/O is synchronous even though the public methods are
    coroutines.
    """

    def __init__(self, base_dir: str | Path | None = None):
        """Create a store rooted at *base_dir* (default: ``.deer-flow``)."""
        self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
        self._seq_counters: dict[str, int] = {}  # thread_id -> current max seq

    # -- Path helpers --

    @staticmethod
    def _safe_component(value: str, label: str) -> str:
        """Return *value* as a single path component or raise ``ValueError``.

        Identifiers end up in file paths, so anything that could step out of
        the store's directory (separators, ``.``/``..``, NUL, empty) is
        rejected instead of being joined blindly.
        """
        text = str(value)
        if not text or text in (".", "..") or any(ch in text for ch in "/\\\x00"):
            raise ValueError(f"Invalid {label} for JSONL event store: {value!r}")
        return text

    def _thread_dir(self, thread_id: str) -> Path:
        """Return the directory holding all run files of *thread_id*."""
        return self._base_dir / "threads" / self._safe_component(thread_id, "thread_id") / "runs"

    def _run_file(self, thread_id: str, run_id: str) -> Path:
        """Return the JSONL file path for one run."""
        return self._thread_dir(thread_id) / f"{self._safe_component(run_id, 'run_id')}.jsonl"

    # -- Sequence bookkeeping --

    def _next_seq(self, thread_id: str) -> int:
        """Increment and return the thread's sequence counter."""
        self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
        return self._seq_counters[thread_id]

    def _ensure_seq_loaded(self, thread_id: str) -> None:
        """Load max seq from existing files if not yet cached."""
        if thread_id in self._seq_counters:
            return
        max_seq = 0
        thread_dir = self._thread_dir(thread_id)
        if thread_dir.exists():
            for path in thread_dir.glob("*.jsonl"):
                for record in self._read_records(path):
                    seq = record.get("seq", 0)
                    # Ignore malformed seq values rather than failing the write path.
                    if isinstance(seq, int) and seq > max_seq:
                        max_seq = seq
        self._seq_counters[thread_id] = max_seq

    # -- File I/O --

    @staticmethod
    def _read_records(path: Path) -> list[dict]:
        """Parse one JSONL file, skipping blank, corrupt, and non-object lines.

        Lines are split on ``"\\n"`` only.  ``str.splitlines()`` would also
        split on U+2028/U+2029/U+0085, which ``json.dumps(ensure_ascii=False)``
        writes unescaped, tearing such records in two and losing them.
        """
        records: list[dict] = []
        for line in path.read_text(encoding="utf-8").split("\n"):
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(record, dict):
                records.append(record)
        return records

    def _write_record(self, record: dict) -> None:
        """Append *record* as one JSON line to its run file."""
        path = self._run_file(record["thread_id"], record["run_id"])
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")

    def _read_thread_events(self, thread_id: str) -> list[dict]:
        """Read all events for a thread, sorted by seq."""
        thread_dir = self._thread_dir(thread_id)
        if not thread_dir.exists():
            return []
        events: list[dict] = []
        for path in sorted(thread_dir.glob("*.jsonl")):
            events.extend(self._read_records(path))
        events.sort(key=lambda e: e.get("seq", 0))
        return events

    def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
        """Read events for a specific run file, sorted by seq."""
        path = self._run_file(thread_id, run_id)
        if not path.exists():
            return []
        events = self._read_records(path)
        events.sort(key=lambda e: e.get("seq", 0))
        return events

    # -- RunEventStore API --

    async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
        """Append one event and return the stored record (including ``seq``).

        Raises:
            ValueError: if *thread_id* or *run_id* is not a safe path component.
        """
        # Validate both identifiers before consuming a sequence number so a
        # rejected write does not leave a gap in the thread's numbering.
        self._run_file(thread_id, run_id)
        self._ensure_seq_loaded(thread_id)
        seq = self._next_seq(thread_id)
        record = {
            "thread_id": thread_id,
            "run_id": run_id,
            "event_type": event_type,
            "category": category,
            "content": content,
            "metadata": metadata or {},
            "seq": seq,
            "created_at": created_at or datetime.now(UTC).isoformat(),
        }
        self._write_record(record)
        return record

    async def put_batch(self, events):
        """Append *events* in order and return the stored records."""
        if not events:
            return []
        return [await self.put(**ev) for ev in events]

    async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
        """Return up to *limit* message events of a thread, ordered by seq.

        ``before_seq`` selects the last *limit* messages below that seq,
        ``after_seq`` the first *limit* above it; with neither, the latest
        *limit* messages are returned.  ``before_seq`` wins if both are given.
        """
        if limit <= 0:
            # ``messages[-0:]`` would return the whole list, not nothing.
            return []
        messages = [e for e in self._read_thread_events(thread_id) if e.get("category") == "message"]
        if before_seq is not None:
            return [e for e in messages if e.get("seq", 0) < before_seq][-limit:]
        if after_seq is not None:
            return [e for e in messages if e.get("seq", 0) > after_seq][:limit]
        return messages[-limit:]

    async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
        """Return the first *limit* events of one run, optionally filtered by type."""
        events = self._read_run_events(thread_id, run_id)
        if event_types is not None:
            events = [e for e in events if e.get("event_type") in event_types]
        return events[:limit]

    async def list_messages_by_run(self, thread_id, run_id):
        """Return all message events of one run, ordered by seq."""
        return [e for e in self._read_run_events(thread_id, run_id) if e.get("category") == "message"]

    async def count_messages(self, thread_id):
        """Return the number of message events stored for a thread."""
        return sum(1 for e in self._read_thread_events(thread_id) if e.get("category") == "message")

    async def delete_by_thread(self, thread_id):
        """Delete every run file of a thread and return the number of events removed."""
        count = len(self._read_thread_events(thread_id))
        thread_dir = self._thread_dir(thread_id)
        if thread_dir.exists():
            for path in thread_dir.glob("*.jsonl"):
                path.unlink()
        # Forget the cached counter so it is re-seeded from disk on next use.
        self._seq_counters.pop(thread_id, None)
        return count

    async def delete_by_run(self, thread_id, run_id):
        """Delete one run's file and return the number of events removed."""
        count = len(self._read_run_events(thread_id, run_id))
        path = self._run_file(thread_id, run_id)
        if path.exists():
            path.unlink()
        return count
class RunJournal(BaseCallbackHandler):
    """LangChain callback handler that captures events to RunEventStore.

    Callbacks append event dicts to an in-memory buffer.  The buffer is
    handed to ``event_store.put_batch`` either by a background task (when a
    callback runs on a thread with a running event loop) or by an explicit
    ``await journal.flush()``.  Token usage and a few convenience fields are
    accumulated in memory and exposed via ``get_completion_data()``.
    """

    def __init__(
        self,
        run_id: str,
        thread_id: str,
        event_store: RunEventStore,
        *,
        track_token_usage: bool = True,
        on_complete: Callable[..., Any] | None = None,
        flush_threshold: int = 20,
    ):
        """Create a journal for one run.

        Args:
            run_id: Run identifier stamped on every event.
            thread_id: Thread identifier stamped on every event.
            event_store: Destination store; only ``put_batch`` is used here.
            track_token_usage: When False, token accumulators stay at zero.
            on_complete: Called with the completion data as keyword
                arguments when the top-level chain ends successfully.
            flush_threshold: Buffer size at which a flush is attempted.
        """
        super().__init__()
        self.run_id = run_id
        self.thread_id = thread_id
        self._store = event_store
        self._track_tokens = track_token_usage
        self._on_complete = on_complete
        self._flush_threshold = flush_threshold

        # Write buffer
        self._buffer: list[dict] = []
        # Background flush tasks still in flight.  Holding references keeps
        # them from being garbage-collected and lets flush() wait for them.
        self._pending_flushes: set[asyncio.Task] = set()

        # Token accumulators
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_tokens = 0
        self._llm_call_count = 0
        self._lead_agent_tokens = 0
        self._subagent_tokens = 0
        self._middleware_tokens = 0

        # Convenience fields
        self._last_ai_msg: str | None = None
        self._first_human_msg: str | None = None
        self._msg_count = 0

        # Latency tracking
        self._llm_start_times: dict[str, float] = {}  # langchain run_id -> start time

    # -- Lifecycle callbacks --

    def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_start`` for the top-level chain only."""
        # Only record for the top-level chain (parent_run_id is None)
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_start",
            category="lifecycle",
            metadata={"input_preview": str(inputs)[:500]},
        )

    def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_end`` for the top-level chain, flush, and notify ``on_complete``."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
        self._flush_sync()
        if self._on_complete:
            self._on_complete(**self.get_completion_data())

    def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_error`` for the top-level chain and flush."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_error",
            category="lifecycle",
            content=str(error),
            metadata={"error_type": type(error).__name__},
        )
        self._flush_sync()

    # -- LLM callbacks --

    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
        """Remember the call's start time and record an ``llm_start`` trace event."""
        self._llm_start_times[str(run_id)] = time.monotonic()
        self._put(
            event_type="llm_start",
            category="trace",
            metadata={"model_name": serialized.get("name", "")},
        )

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record the completed LLM call and accumulate its token usage.

        Always emits an ``llm_end`` trace event.  A non-empty string reply
        tagged ``lead_agent`` additionally emits an ``ai_message`` event.
        """
        from deerflow.runtime.serialization import serialize_lc_object

        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            logger.debug("on_llm_end: could not extract message from response")
            return

        serialized_msg = serialize_lc_object(message)
        caller = self._identify_caller(kwargs)

        # Latency.  Compare against None: a start time of 0.0 is still valid.
        start = self._llm_start_times.pop(str(run_id), None)
        latency_ms = int((time.monotonic() - start) * 1000) if start is not None else None

        # Token usage from message
        usage = getattr(message, "usage_metadata", None)
        usage_dict = dict(usage) if usage else {}

        raw_content = getattr(message, "content", "")
        text_content = raw_content if isinstance(raw_content, str) else str(raw_content)

        # trace event: llm_end (every LLM call)
        self._put(
            event_type="llm_end",
            category="trace",
            content=text_content,
            metadata={
                "message": serialized_msg,
                "caller": caller,
                "usage": usage_dict,
                "latency_ms": latency_ms,
            },
        )

        # message event: ai_message (only lead_agent final replies with content)
        if caller == "lead_agent" and isinstance(raw_content, str) and raw_content:
            tool_calls = getattr(message, "tool_calls", None) or []
            tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
            resp_meta = getattr(message, "response_metadata", None) or {}
            model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
            self._put(
                event_type="ai_message",
                category="message",
                content=raw_content,
                metadata={
                    "model_name": model_name,
                    "tool_calls": tool_calls_summary,
                },
            )
            self._last_ai_msg = raw_content[:2000]
            self._msg_count += 1

        # Token accumulation
        input_tk = usage_dict.get("input_tokens", 0) or 0
        output_tk = usage_dict.get("output_tokens", 0) or 0
        total_tk = usage_dict.get("total_tokens", 0) or 0
        if self._track_tokens and total_tk > 0:
            self._total_input_tokens += input_tk
            self._total_output_tokens += output_tk
            self._total_tokens += total_tk
            self._llm_call_count += 1
            if caller.startswith("subagent:"):
                self._subagent_tokens += total_tk
            elif caller.startswith("middleware:"):
                self._middleware_tokens += total_tk
            else:
                # "lead_agent" and untagged ("unknown") calls land here.
                self._lead_agent_tokens += total_tk

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Discard the call's start time and record an ``llm_error`` trace event."""
        self._llm_start_times.pop(str(run_id), None)
        self._put(event_type="llm_error", category="trace", content=str(error))

    # -- Tool callbacks --

    def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``tool_start`` trace event (arguments truncated to 2000 chars)."""
        self._put(
            event_type="tool_start",
            category="trace",
            metadata={
                "tool_name": serialized.get("name", ""),
                "tool_call_id": kwargs.get("tool_call_id"),
                "args": str(input_str)[:2000],
            },
        )

    def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``tool_end`` trace event carrying the tool output."""
        self._put(
            event_type="tool_end",
            category="trace",
            content=str(output),
            metadata={
                "tool_name": kwargs.get("name", ""),
                "tool_call_id": kwargs.get("tool_call_id"),
                "status": "success",
            },
        )

    def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``tool_error`` trace event carrying the error text."""
        self._put(
            event_type="tool_error",
            category="trace",
            content=str(error),
            metadata={
                "tool_name": kwargs.get("name", ""),
                "tool_call_id": kwargs.get("tool_call_id"),
            },
        )

    # -- Custom event callback --

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a custom event.

        ``summarization`` produces both a trace event and a ``summary``
        message event; any other name becomes a trace event named *name*.
        """
        from deerflow.runtime.serialization import serialize_lc_object

        if name == "summarization":
            data_dict = data if isinstance(data, dict) else {}
            self._put(
                event_type="summarization",
                category="trace",
                content=data_dict.get("summary", ""),
                metadata={
                    "replaced_message_ids": data_dict.get("replaced_message_ids", []),
                    "replaced_count": data_dict.get("replaced_count", 0),
                },
            )
            self._put(
                event_type="summary",
                category="message",
                content=data_dict.get("summary", ""),
                metadata={"replaced_count": data_dict.get("replaced_count", 0)},
            )
        else:
            event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
            self._put(
                event_type=name,
                category="trace",
                metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
            )

    # -- Internal methods --

    def _put(self, *, event_type: str, category: str, content: str = "", metadata: dict | None = None) -> None:
        """Buffer one event and attempt a flush once the threshold is reached."""
        self._buffer.append(
            {
                "thread_id": self.thread_id,
                "run_id": self.run_id,
                "event_type": event_type,
                "category": category,
                "content": content,
                "metadata": metadata or {},
                "created_at": datetime.now(UTC).isoformat(),
            }
        )
        if len(self._buffer) >= self._flush_threshold:
            self._flush_sync()

    def _drain_buffer(self) -> list[dict]:
        """Detach and return the buffered events, leaving an empty buffer."""
        batch, self._buffer = self._buffer, []
        return batch

    def _flush_sync(self) -> None:
        """Schedule the buffered events for writing from a synchronous callback.

        BaseCallbackHandler methods are synchronous, so the async
        ``put_batch`` is scheduled as a task on the running event loop.  When
        the calling thread has no running loop the events stay buffered so a
        later ``await flush()`` can still persist them.
        """
        if not self._buffer:
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("RunJournal: no event loop, keeping %d events buffered", len(self._buffer))
            return
        task = loop.create_task(self._flush_async(self._drain_buffer()))
        # The loop only keeps weak references to tasks; track ours until done.
        self._pending_flushes.add(task)
        task.add_done_callback(self._pending_flushes.discard)

    async def _flush_async(self, batch: list[dict]) -> None:
        """Write *batch* to the store, logging (not raising) on failure."""
        try:
            await self._store.put_batch(batch)
        except Exception:
            logger.warning("RunJournal: failed to flush %d events", len(batch), exc_info=True)

    def _identify_caller(self, kwargs: dict) -> str:
        """Return the first caller tag found in ``kwargs["tags"]``, else ``"unknown"``."""
        for tag in kwargs.get("tags") or []:
            if isinstance(tag, str) and (tag == "lead_agent" or tag.startswith(("subagent:", "middleware:"))):
                return tag
        return "unknown"

    # -- Public methods (called by worker) --

    def set_first_human_message(self, content: str) -> None:
        """Record the first human message for convenience fields."""
        self._first_human_msg = content[:2000] if content else None

    async def flush(self) -> None:
        """Force flush. Used in cancel/error paths.

        Waits for background flushes scheduled earlier, then writes whatever
        is still buffered, so earlier batches reach the store first.  Errors
        from this final ``put_batch`` propagate to the caller.
        """
        pending = list(self._pending_flushes)
        if pending:
            # _flush_async already logs its own failures; just wait them out.
            await asyncio.gather(*pending, return_exceptions=True)
        batch = self._drain_buffer()
        if batch:
            await self._store.put_batch(batch)

    def get_completion_data(self) -> dict:
        """Return accumulated token and message data for run completion."""
        return {
            "total_input_tokens": self._total_input_tokens,
            "total_output_tokens": self._total_output_tokens,
            "total_tokens": self._total_tokens,
            "llm_call_count": self._llm_call_count,
            "lead_agent_tokens": self._lead_agent_tokens,
            "subagent_tokens": self._subagent_tokens,
            "middleware_tokens": self._middleware_tokens,
            "message_count": self._msg_count,
            "last_ai_message": self._last_ai_msg,
            "first_human_message": self._first_human_msg,
        }
+ """ + + def __init__(self, store: RunStore | None = None) -> None: self._runs: dict[str, RunRecord] = {} self._lock = asyncio.Lock() + self._store = store async def create( self, @@ -71,6 +81,20 @@ class RunManager: ) async with self._lock: self._runs[run_id] = record + if self._store is not None: + try: + await self._store.put( + run_id, + thread_id=thread_id, + assistant_id=assistant_id, + status=RunStatus.pending.value, + multitask_strategy=multitask_strategy, + metadata=metadata or {}, + kwargs=kwargs or {}, + created_at=now, + ) + except Exception: + logger.warning("Failed to persist run %s to store", run_id, exc_info=True) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record @@ -96,6 +120,11 @@ class RunManager: record.updated_at = _now_iso() if error is not None: record.error = error + if self._store is not None: + try: + await self._store.update_status(run_id, status.value, error=error) + except Exception: + logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) logger.info("Run %s -> %s", run_id, status.value) async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: @@ -185,6 +214,21 @@ class RunManager: ) self._runs[run_id] = record + if self._store is not None: + try: + await self._store.put( + run_id, + thread_id=thread_id, + assistant_id=assistant_id, + status=RunStatus.pending.value, + multitask_strategy=multitask_strategy, + metadata=metadata or {}, + kwargs=kwargs or {}, + created_at=now, + ) + except Exception: + logger.warning("Failed to persist run %s to store", run_id, exc_info=True) + logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index deaec055a..82b77ca1e 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -45,6 +45,8 
@@ async def run_agent( stream_subgraphs: bool = False, interrupt_before: list[str] | Literal["*"] | None = None, interrupt_after: list[str] | Literal["*"] | None = None, + event_store: Any | None = None, + run_events_config: Any | None = None, ) -> None: """Execute an agent in the background, publishing events to *bridge*.""" @@ -52,6 +54,30 @@ async def run_agent( thread_id = record.thread_id requested_modes: set[str] = set(stream_modes or ["values"]) + # Initialize RunJournal for event capture + journal = None + if event_store is not None: + from deerflow.runtime.journal import RunJournal + + journal = RunJournal( + run_id=run_id, + thread_id=thread_id, + event_store=event_store, + track_token_usage=getattr(run_events_config, "track_token_usage", True), + ) + + # Write human_message event + user_input = _extract_user_input(graph_input) + if user_input: + await event_store.put( + thread_id=thread_id, + run_id=run_id, + event_type="human_message", + category="message", + content=user_input, + ) + journal.set_first_human_message(user_input) + # Track whether "events" was requested but skipped if "events" in requested_modes: logger.info( @@ -92,6 +118,10 @@ async def run_agent( runtime = Runtime(context={"thread_id": thread_id}, store=store) config.setdefault("configurable", {})["__pregel_runtime"] = runtime + # Inject RunJournal as a callback + if journal is not None: + config.setdefault("callbacks", []).append(journal) + runnable_config = RunnableConfig(**config) agent = agent_factory(config=runnable_config) @@ -206,6 +236,13 @@ async def run_agent( ) finally: + # Flush any buffered journal events + if journal is not None: + try: + await journal.flush() + except Exception: + logger.warning("Failed to flush journal for run %s", run_id, exc_info=True) + await bridge.publish_end(run_id) asyncio.create_task(bridge.cleanup(run_id, delay=60)) @@ -227,6 +264,23 @@ def _lg_mode_to_sse_event(mode: str) -> str: return mode +def _extract_user_input(graph_input: dict) -> str: 
+ """Extract user input text from graph_input for event recording.""" + messages = graph_input.get("messages") + if not messages: + return "" + # Take the last message (usually the user's input) + last = messages[-1] if isinstance(messages, list) else messages + if isinstance(last, str): + return last + if hasattr(last, "content"): + content = last.content + return content if isinstance(content, str) else str(content) + if isinstance(last, dict): + return str(last.get("content", "")) + return "" + + def _unpack_stream_item( item: Any, lg_modes: list[str], diff --git a/backend/tests/test_phase2b_integration.py b/backend/tests/test_phase2b_integration.py new file mode 100644 index 000000000..0858efdce --- /dev/null +++ b/backend/tests/test_phase2b_integration.py @@ -0,0 +1,154 @@ +"""Phase 2-B integration tests. + +End-to-end test: simulate a run's complete lifecycle, verify data +is correctly written to both RunStore and RunEventStore. +""" + +import asyncio +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from deerflow.runtime.events.store.memory import MemoryRunEventStore +from deerflow.runtime.journal import RunJournal +from deerflow.runtime.runs.store.memory import MemoryRunStore + + +def _make_llm_response(content="Hello", usage=None): + msg = MagicMock() + msg.content = content + msg.tool_calls = [] + msg.response_metadata = {"model_name": "test-model"} + msg.usage_metadata = usage + + gen = MagicMock() + gen.message = msg + + response = MagicMock() + response.generations = [[gen]] + return response + + +class TestRunLifecycle: + @pytest.mark.anyio + async def test_full_run_lifecycle(self): + """Simulate a complete run lifecycle with RunStore + RunEventStore.""" + run_store = MemoryRunStore() + event_store = MemoryRunEventStore() + + # 1. Create run + await run_store.put("r1", thread_id="t1", status="pending") + + # 2. 
Write human_message + await event_store.put( + thread_id="t1", + run_id="r1", + event_type="human_message", + category="message", + content="What is AI?", + ) + + # 3. Simulate RunJournal callback sequence + on_complete_data = {} + + def on_complete(**data): + on_complete_data.update(data) + + journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100) + journal.set_first_human_message("What is AI?") + + # chain_start (top-level) + journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None) + + # llm_start + llm_end + llm_run_id = uuid4() + journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"]) + usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150} + journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"]) + + # chain_end (triggers on_complete + flush_sync which creates a task) + journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + await journal.flush() + # Let event loop process any pending flush tasks from _flush_sync + await asyncio.sleep(0.05) + + # 4. Verify messages + messages = await event_store.list_messages("t1") + assert len(messages) == 2 # human + ai + assert messages[0]["event_type"] == "human_message" + assert messages[1]["event_type"] == "ai_message" + assert messages[1]["content"] == "AI is artificial intelligence." + + # 5. Verify events + events = await event_store.list_events("t1", "r1") + event_types = {e["event_type"] for e in events} + assert "run_start" in event_types + assert "llm_start" in event_types + assert "llm_end" in event_types + assert "run_end" in event_types + + # 6. 
Verify on_complete data + assert on_complete_data["total_tokens"] == 150 + assert on_complete_data["llm_call_count"] == 1 + assert on_complete_data["lead_agent_tokens"] == 150 + assert on_complete_data["message_count"] == 1 + assert on_complete_data["last_ai_message"] == "AI is artificial intelligence." + assert on_complete_data["first_human_message"] == "What is AI?" + + @pytest.mark.anyio + async def test_run_with_tool_calls(self): + """Simulate a run that uses tools.""" + event_store = MemoryRunEventStore() + journal = RunJournal("r1", "t1", event_store, flush_threshold=100) + + # tool_start + tool_end + journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4()) + journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search") + await journal.flush() + + events = await event_store.list_events("t1", "r1") + assert len(events) == 2 + assert events[0]["event_type"] == "tool_start" + assert events[1]["event_type"] == "tool_end" + + @pytest.mark.anyio + async def test_multi_run_thread(self): + """Multiple runs on the same thread maintain unified seq ordering.""" + event_store = MemoryRunEventStore() + + # Run 1 + await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1") + await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1") + + # Run 2 + await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2") + await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2") + + messages = await event_store.list_messages("t1") + assert len(messages) == 4 + assert [m["seq"] for m in messages] == [1, 2, 3, 4] + assert messages[0]["run_id"] == "r1" + assert messages[2]["run_id"] == "r2" + + @pytest.mark.anyio + async def test_runmanager_with_store_backing(self): + """RunManager persists to RunStore when one is provided.""" + from 
deerflow.runtime.runs.manager import RunManager + + run_store = MemoryRunStore() + mgr = RunManager(store=run_store) + + record = await mgr.create("t1", assistant_id="lead_agent") + # Verify persisted to store + row = await run_store.get(record.run_id) + assert row is not None + assert row["thread_id"] == "t1" + assert row["status"] == "pending" + + # Status update + from deerflow.runtime.runs.schemas import RunStatus + + await mgr.set_status(record.run_id, RunStatus.running) + row = await run_store.get(record.run_id) + assert row["status"] == "running" diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py index a493a486b..f8c6cb177 100644 --- a/backend/tests/test_run_event_store.py +++ b/backend/tests/test_run_event_store.py @@ -1,14 +1,7 @@ -"""Tests for RunEventStore ABC + MemoryRunEventStore. +"""Tests for RunEventStore contract across all backends. -Covers: -- Basic write and query (put, seq assignment, cross-thread independence) -- list_messages (category filtering, pagination, cross-run ordering) -- list_events (run filtering, event_types filtering) -- list_messages_by_run -- count_messages -- put_batch -- delete_by_thread, delete_by_run -- Edge cases (empty thread/run) +Uses a helper to create the store for each backend type. +Memory tests run directly; DB and JSONL tests create stores inside each test. 
""" import pytest @@ -35,7 +28,6 @@ class TestPutAndSeq: assert record["event_type"] == "human_message" assert record["category"] == "message" assert record["content"] == "hello" - assert record["metadata"] == {} assert "created_at" in record @pytest.mark.anyio @@ -91,7 +83,6 @@ class TestListMessages: @pytest.mark.anyio async def test_before_seq_pagination(self, store): - # Put 10 messages with seq 1..10 for i in range(10): await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i)) messages = await store.list_messages("t1", before_seq=6, limit=3) @@ -236,7 +227,6 @@ class TestDelete: await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace") count = await store.delete_by_run("t1", "r2") assert count == 2 - # r1 events should still be there messages = await store.list_messages("t1") assert len(messages) == 1 assert messages[0]["run_id"] == "r1" @@ -270,3 +260,145 @@ class TestEdgeCases: @pytest.mark.anyio async def test_empty_thread_count_messages(self, store): assert await store.count_messages("empty") == 0 + + +# -- DB-specific tests -- + + +class TestDbRunEventStore: + """Tests for DbRunEventStore with temp SQLite.""" + + @pytest.mark.anyio + async def test_basic_crud(self, tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + from deerflow.runtime.events.store.db import DbRunEventStore + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + s = DbRunEventStore(get_session_factory()) + + r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi") + assert r["seq"] == 1 + r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello") + assert r2["seq"] == 2 + + messages = await s.list_messages("t1") + assert len(messages) == 2 + + count = await s.count_messages("t1") + assert count == 
2 + + await close_engine() + + @pytest.mark.anyio + async def test_trace_content_truncation(self, tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + from deerflow.runtime.events.store.db import DbRunEventStore + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + s = DbRunEventStore(get_session_factory(), max_trace_content=100) + + long = "x" * 200 + r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long) + assert len(r["content"]) == 100 + assert r["metadata"].get("content_truncated") is True + + # message content NOT truncated + m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long) + assert len(m["content"]) == 200 + + await close_engine() + + @pytest.mark.anyio + async def test_pagination(self, tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + from deerflow.runtime.events.store.db import DbRunEventStore + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + s = DbRunEventStore(get_session_factory()) + + for i in range(10): + await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i)) + + # before_seq + msgs = await s.list_messages("t1", before_seq=6, limit=3) + assert [m["seq"] for m in msgs] == [3, 4, 5] + + # after_seq + msgs = await s.list_messages("t1", after_seq=7, limit=3) + assert [m["seq"] for m in msgs] == [8, 9, 10] + + # default (latest) + msgs = await s.list_messages("t1", limit=3) + assert [m["seq"] for m in msgs] == [8, 9, 10] + + await close_engine() + + @pytest.mark.anyio + async def test_delete(self, tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + from deerflow.runtime.events.store.db import DbRunEventStore + + url = 
f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + s = DbRunEventStore(get_session_factory()) + + await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") + await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message") + c = await s.delete_by_run("t1", "r2") + assert c == 1 + assert await s.count_messages("t1") == 1 + + c = await s.delete_by_thread("t1") + assert c == 1 + assert await s.count_messages("t1") == 0 + + await close_engine() + + +# -- JSONL-specific tests -- + + +class TestJsonlRunEventStore: + @pytest.mark.anyio + async def test_basic_crud(self, tmp_path): + from deerflow.runtime.events.store.jsonl import JsonlRunEventStore + + s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") + r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi") + assert r["seq"] == 1 + messages = await s.list_messages("t1") + assert len(messages) == 1 + + @pytest.mark.anyio + async def test_file_at_correct_path(self, tmp_path): + from deerflow.runtime.events.store.jsonl import JsonlRunEventStore + + s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") + await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") + assert (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r1.jsonl").exists() + + @pytest.mark.anyio + async def test_cross_run_messages(self, tmp_path): + from deerflow.runtime.events.store.jsonl import JsonlRunEventStore + + s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") + await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") + await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message") + messages = await s.list_messages("t1") + assert len(messages) == 2 + assert [m["seq"] for m in messages] == [1, 2] + + @pytest.mark.anyio + async def test_delete_by_run(self, tmp_path): + from 
deerflow.runtime.events.store.jsonl import JsonlRunEventStore + + s = JsonlRunEventStore(base_dir=tmp_path / "jsonl") + await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message") + await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message") + c = await s.delete_by_run("t1", "r2") + assert c == 1 + assert not (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r2.jsonl").exists() + assert await s.count_messages("t1") == 1 diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py new file mode 100644 index 000000000..28dcfbddc --- /dev/null +++ b/backend/tests/test_run_journal.py @@ -0,0 +1,230 @@ +"""Tests for RunJournal callback handler. + +Uses MemoryRunEventStore as the backend for direct event inspection. +""" + +import asyncio +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest + +from deerflow.runtime.events.store.memory import MemoryRunEventStore +from deerflow.runtime.journal import RunJournal + + +@pytest.fixture +def journal_setup(): + store = MemoryRunEventStore() + on_complete_data = {} + + def on_complete(**data): + on_complete_data.update(data) + + j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100) + return j, store, on_complete_data + + +def _make_llm_response(content="Hello", usage=None): + """Create a mock LLM response with a message.""" + msg = MagicMock() + msg.content = content + msg.tool_calls = [] + msg.response_metadata = {"model_name": "test-model"} + msg.usage_metadata = usage + + gen = MagicMock() + gen.message = msg + + response = MagicMock() + response.generations = [[gen]] + return response + + +class TestLlmCallbacks: + @pytest.mark.anyio + async def test_on_llm_end_produces_trace_event(self, journal_setup): + j, store, _ = journal_setup + run_id = uuid4() + j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"]) + await 
j.flush() + events = await store.list_events("t1", "r1") + trace_events = [e for e in events if e["event_type"] == "llm_end"] + assert len(trace_events) == 1 + assert trace_events[0]["category"] == "trace" + + @pytest.mark.anyio + async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup): + j, store, _ = journal_setup + run_id = uuid4() + j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"]) + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 1 + assert messages[0]["event_type"] == "ai_message" + assert messages[0]["content"] == "Answer" + + @pytest.mark.anyio + async def test_on_llm_end_subagent_no_ai_message(self, journal_setup): + j, store, _ = journal_setup + run_id = uuid4() + j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"]) + j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"]) + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 0 + + @pytest.mark.anyio + async def test_token_accumulation(self, journal_setup): + j, store, on_complete_data = journal_setup + usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} + j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"]) + assert j._total_input_tokens == 30 + assert j._total_output_tokens == 15 + assert j._total_tokens == 45 + assert j._llm_call_count == 2 + + @pytest.mark.anyio + async def test_caller_token_classification(self, journal_setup): + j, store, _ = journal_setup + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + 
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"]) + j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"]) + j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"]) + j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"]) + assert j._lead_agent_tokens == 15 + assert j._subagent_tokens == 15 + assert j._middleware_tokens == 15 + + @pytest.mark.anyio + async def test_usage_metadata_none_no_crash(self, journal_setup): + j, store, _ = journal_setup + j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"]) + # Should not raise + await j.flush() + + @pytest.mark.anyio + async def test_latency_tracking(self, journal_setup): + j, store, _ = journal_setup + run_id = uuid4() + j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"]) + await j.flush() + events = await store.list_events("t1", "r1") + llm_end = [e for e in events if e["event_type"] == "llm_end"][0] + assert "latency_ms" in llm_end["metadata"] + assert llm_end["metadata"]["latency_ms"] is not None + + +class TestLifecycleCallbacks: + @pytest.mark.anyio + async def test_on_chain_end_triggers_on_complete(self, journal_setup): + j, store, on_complete_data = journal_setup + j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) + j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + assert "total_tokens" in on_complete_data + assert "message_count" in on_complete_data + + @pytest.mark.anyio + async def test_nested_chain_ignored(self, journal_setup): + j, store, on_complete_data = journal_setup + parent_id = uuid4() + j.on_chain_start({}, {}, run_id=uuid4(), 
parent_run_id=parent_id) + j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id) + await j.flush() + events = await store.list_events("t1", "r1") + lifecycle = [e for e in events if e["category"] == "lifecycle"] + assert len(lifecycle) == 0 + + +class TestToolCallbacks: + @pytest.mark.anyio + async def test_tool_start_end_produce_trace(self, journal_setup): + j, store, _ = journal_setup + j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4()) + j.on_tool_end("results", run_id=uuid4(), name="web_search") + await j.flush() + events = await store.list_events("t1", "r1") + types = {e["event_type"] for e in events} + assert "tool_start" in types + assert "tool_end" in types + + +class TestCustomEvents: + @pytest.mark.anyio + async def test_summarization_event(self, journal_setup): + j, store, _ = journal_setup + j.on_custom_event( + "summarization", + {"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]}, + run_id=uuid4(), + ) + await j.flush() + events = await store.list_events("t1", "r1") + trace = [e for e in events if e["event_type"] == "summarization"] + assert len(trace) == 1 + messages = await store.list_messages("t1") + assert len(messages) == 1 + assert messages[0]["event_type"] == "summary" + + +class TestBufferFlush: + @pytest.mark.anyio + async def test_flush_threshold(self, journal_setup): + j, store, _ = journal_setup + j._flush_threshold = 3 + j.on_tool_start({"name": "a"}, "x", run_id=uuid4()) + j.on_tool_start({"name": "b"}, "x", run_id=uuid4()) + # Buffer has 2 events, not yet flushed + assert len(j._buffer) == 2 + j.on_tool_start({"name": "c"}, "x", run_id=uuid4()) + # Buffer should have been flushed (threshold=3 triggers flush) + # Give the async task a chance to complete + await asyncio.sleep(0.1) + events = await store.list_events("t1", "r1") + assert len(events) >= 3 + + +class TestIdentifyCaller: + def test_lead_agent_tag(self, journal_setup): + j, _, _ = journal_setup + assert 
j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent" + + def test_subagent_tag(self, journal_setup): + j, _, _ = journal_setup + assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research" + + def test_middleware_tag(self, journal_setup): + j, _, _ = journal_setup + assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization" + + def test_no_tags_returns_unknown(self, journal_setup): + j, _, _ = journal_setup + assert j._identify_caller({"tags": []}) == "unknown" + assert j._identify_caller({}) == "unknown" + + +class TestPublicMethods: + @pytest.mark.anyio + async def test_set_first_human_message(self, journal_setup): + j, _, _ = journal_setup + j.set_first_human_message("Hello world") + assert j._first_human_msg == "Hello world" + + @pytest.mark.anyio + async def test_get_completion_data(self, journal_setup): + j, _, _ = journal_setup + j._total_tokens = 100 + j._msg_count = 5 + data = j.get_completion_data() + assert data["total_tokens"] == 100 + assert data["message_count"] == 5 diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py new file mode 100644 index 000000000..5198952c4 --- /dev/null +++ b/backend/tests/test_run_repository.py @@ -0,0 +1,155 @@ +"""Tests for RunRepository (SQLAlchemy-backed RunStore). + +Uses a temp SQLite DB to test ORM-backed CRUD operations. 
+""" + +import pytest + +from deerflow.persistence.repositories.run_repo import RunRepository + + +async def _make_repo(tmp_path): + from deerflow.persistence.engine import get_session_factory, init_engine + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + return RunRepository(get_session_factory()) + + +async def _cleanup(): + from deerflow.persistence.engine import close_engine + + await close_engine() + + +class TestRunRepository: + @pytest.mark.anyio + async def test_put_and_get(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="pending") + row = await repo.get("r1") + assert row is not None + assert row["run_id"] == "r1" + assert row["thread_id"] == "t1" + assert row["status"] == "pending" + await _cleanup() + + @pytest.mark.anyio + async def test_get_missing_returns_none(self, tmp_path): + repo = await _make_repo(tmp_path) + assert await repo.get("nope") is None + await _cleanup() + + @pytest.mark.anyio + async def test_update_status(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1") + await repo.update_status("r1", "running") + row = await repo.get("r1") + assert row["status"] == "running" + await _cleanup() + + @pytest.mark.anyio + async def test_update_status_with_error(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1") + await repo.update_status("r1", "error", error="boom") + row = await repo.get("r1") + assert row["status"] == "error" + assert row["error"] == "boom" + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_thread(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1") + await repo.put("r2", thread_id="t1") + await repo.put("r3", thread_id="t2") + rows = await repo.list_by_thread("t1") + assert len(rows) == 2 + assert all(r["thread_id"] == "t1" for r in rows) + await _cleanup() + + 
@pytest.mark.anyio + async def test_list_by_thread_owner_filter(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", owner_id="alice") + await repo.put("r2", thread_id="t1", owner_id="bob") + rows = await repo.list_by_thread("t1", owner_id="alice") + assert len(rows) == 1 + assert rows[0]["owner_id"] == "alice" + await _cleanup() + + @pytest.mark.anyio + async def test_delete(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1") + await repo.delete("r1") + assert await repo.get("r1") is None + await _cleanup() + + @pytest.mark.anyio + async def test_delete_nonexistent_is_noop(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.delete("nope") # should not raise + await _cleanup() + + @pytest.mark.anyio + async def test_list_pending(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="pending") + await repo.put("r2", thread_id="t1", status="running") + await repo.put("r3", thread_id="t2", status="pending") + pending = await repo.list_pending() + assert len(pending) == 2 + assert all(r["status"] == "pending" for r in pending) + await _cleanup() + + @pytest.mark.anyio + async def test_update_run_completion(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="running") + await repo.update_run_completion( + "r1", + status="success", + total_input_tokens=100, + total_output_tokens=50, + total_tokens=150, + llm_call_count=2, + lead_agent_tokens=120, + subagent_tokens=20, + middleware_tokens=10, + message_count=3, + last_ai_message="The answer is 42", + first_human_message="What is the meaning?", + ) + row = await repo.get("r1") + assert row["status"] == "success" + assert row["total_tokens"] == 150 + assert row["llm_call_count"] == 2 + assert row["lead_agent_tokens"] == 120 + assert row["message_count"] == 3 + assert row["last_ai_message"] == "The answer is 42" + assert 
row["first_human_message"] == "What is the meaning?" + await _cleanup() + + @pytest.mark.anyio + async def test_metadata_preserved(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", metadata={"key": "value"}) + row = await repo.get("r1") + assert row["metadata"] == {"key": "value"} + await _cleanup() + + @pytest.mark.anyio + async def test_kwargs_with_non_serializable(self, tmp_path): + """kwargs containing non-JSON-serializable objects should be safely handled.""" + repo = await _make_repo(tmp_path) + + class Dummy: + pass + + await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()}) + row = await repo.get("r1") + assert "obj" in row["kwargs"] + await _cleanup()