feat(persistence): add ORM models, repositories, DB/JSONL event stores, RunJournal, and API endpoints

Phase 2-B: run persistence + event storage + token tracking.

- ORM models: RunRow (with token fields), ThreadMetaRow, RunEventRow
- RunRepository implements RunStore ABC via SQLAlchemy ORM
- ThreadMetaRepository with owner access control
- DbRunEventStore with trace content truncation and cursor pagination
- JsonlRunEventStore with per-run files and seq recovery from disk
- RunJournal (BaseCallbackHandler) captures LLM/tool/lifecycle events,
  accumulates token usage by caller type, buffers and flushes to store
- RunManager now accepts optional RunStore for persistent backing
- Worker creates RunJournal, writes human_message, injects callbacks
- Gateway deps use factory functions (RunRepository when DB available)
- New endpoints: messages, run messages, run events, token-usage (request sketch below)
- ThreadCreateRequest gains assistant_id field
- 92 tests pass (33 new), zero regressions
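
For reference, a request against the new token-usage endpoint returns an aggregate
shaped like the following (the /threads prefix is assumed from the router mount;
values are illustrative):

    GET /threads/t1/token-usage
    {
        "thread_id": "t1",
        "total_tokens": 150,
        "total_input_tokens": 50,
        "total_output_tokens": 100,
        "total_runs": 1,
        "by_model": {"test-model": {"tokens": 150, "runs": 1}},
        "by_caller": {"lead_agent": 150, "subagent": 0, "middleware": 0}
    }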

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
rayhpeng 2026-04-02 19:03:38 +08:00
parent 23eacf9533
commit e3179cd54d
21 changed files with 1946 additions and 29 deletions

View File

@ -31,27 +31,23 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
from deerflow.config import get_app_config
from deerflow.persistence.engine import close_engine, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.runs.store.memory import MemoryRunStore
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
# Initialize persistence layer from unified database config
config = get_app_config()
await init_engine_from_config(config.database)
# Initialize run store (MemoryRunStore for now; switch to ORM-backed
# RunRepository when models are implemented)
app.state.run_store = MemoryRunStore()
# Initialize run store (RunRepository if DB available, else MemoryRunStore)
app.state.run_store = _make_run_store()
# Initialize run event store (MemoryRunEventStore for now)
# TODO(Phase 2-B): switch to db/jsonl backend based on config.run_events.backend
from deerflow.runtime.events.store.memory import MemoryRunEventStore
app.state.run_event_store = MemoryRunEventStore()
# Initialize run event store based on config
app.state.run_event_store = _make_run_event_store(config)
# RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store)
try:
yield
@ -59,6 +55,32 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
await close_engine()
# ---------------------------------------------------------------------------
# Factories
# ---------------------------------------------------------------------------
def _make_run_store() -> RunStore:
"""Create a RunStore: RunRepository if DB engine is available, else MemoryRunStore."""
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is not None:
from deerflow.persistence.repositories.run_repo import RunRepository
return RunRepository(sf)
from deerflow.runtime.runs.store.memory import MemoryRunStore
return MemoryRunStore()
def _make_run_event_store(config) -> RunEventStore:
from deerflow.runtime.events.store import make_run_event_store
run_events_config = getattr(config, "run_events", None)
return make_run_event_store(run_events_config)
# ---------------------------------------------------------------------------
# Getters -- called by routers per-request
# ---------------------------------------------------------------------------

View File

@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
@ -263,3 +263,77 @@ async def stream_existing_run(
"X-Accel-Buffering": "no",
},
)
# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/messages")
async def list_thread_messages(
thread_id: str,
request: Request,
limit: int = Query(default=50, le=200),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> list[dict]:
"""Return displayable messages for a thread (across all runs)."""
event_store = get_run_event_store(request)
return await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
@router.get("/{thread_id}/runs/{run_id}/messages")
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
"""Return displayable messages for a specific run."""
event_store = get_run_event_store(request)
return await event_store.list_messages_by_run(thread_id, run_id)
@router.get("/{thread_id}/runs/{run_id}/events")
async def list_run_events(
thread_id: str,
run_id: str,
request: Request,
event_types: str | None = Query(default=None),
limit: int = Query(default=500, le=2000),
) -> list[dict]:
"""Return the full event stream for a run (debug/audit)."""
event_store = get_run_event_store(request)
types = event_types.split(",") if event_types else None
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
runs = await run_store.list_by_thread(thread_id, limit=10000)
completed = [r for r in runs if r.get("status") in ("success", "error")]
total_tokens = sum(r.get("total_tokens", 0) for r in completed)
total_input = sum(r.get("total_input_tokens", 0) for r in completed)
total_output = sum(r.get("total_output_tokens", 0) for r in completed)
by_model: dict[str, dict] = {}
for r in completed:
model = r.get("model_name") or "unknown"
entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
entry["tokens"] += r.get("total_tokens", 0)
entry["runs"] += 1
by_caller = {
"lead_agent": sum(r.get("lead_agent_tokens", 0) for r in completed),
"subagent": sum(r.get("subagent_tokens", 0) for r in completed),
"middleware": sum(r.get("middleware_tokens", 0) for r in completed),
}
return {
"thread_id": thread_id,
"total_tokens": total_tokens,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_runs": len(completed),
"by_model": by_model,
"by_caller": by_caller,
}

View File

@ -63,6 +63,7 @@ class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")

View File

@ -17,7 +17,7 @@ from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_store, get_stream_bridge
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
@ -245,6 +245,12 @@ async def start_run(
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
event_store = get_run_event_store(request)
# Get run_events config for journal
from deerflow.config import get_app_config
run_events_config = getattr(get_app_config(), "run_events", None)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
@ -287,6 +293,8 @@ async def start_run(
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
event_store=event_store,
run_events_config=run_events_config,
)
)
record.task = task

View File

@ -0,0 +1,5 @@
from deerflow.persistence.models.run import RunRow
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.models.thread_meta import ThreadMetaRow
__all__ = ["RunEventRow", "RunRow", "ThreadMetaRow"]

View File

@ -0,0 +1,49 @@
"""ORM model for run metadata."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, Index, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class RunRow(Base):
__tablename__ = "runs"
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128))
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
status: Mapped[str] = mapped_column(String(20), default="pending")
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
model_name: Mapped[str | None] = mapped_column(String(128))
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
error: Mapped[str | None] = mapped_column(Text)
# Convenience fields (for listing pages without querying RunEventStore)
message_count: Mapped[int] = mapped_column(default=0)
first_human_message: Mapped[str | None] = mapped_column(Text)
last_ai_message: Mapped[str | None] = mapped_column(Text)
# Token usage (accumulated in-memory by RunJournal, written on run completion)
total_input_tokens: Mapped[int] = mapped_column(default=0)
total_output_tokens: Mapped[int] = mapped_column(default=0)
total_tokens: Mapped[int] = mapped_column(default=0)
llm_call_count: Mapped[int] = mapped_column(default=0)
lead_agent_tokens: Mapped[int] = mapped_column(default=0)
subagent_tokens: Mapped[int] = mapped_column(default=0)
middleware_tokens: Mapped[int] = mapped_column(default=0)
# Follow-up association
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)

View File

@ -0,0 +1,30 @@
"""ORM model for run events."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, Index, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class RunEventRow(Base):
__tablename__ = "run_events"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
category: Mapped[str] = mapped_column(String(16), nullable=False)
# "message" | "trace" | "lifecycle"
content: Mapped[str] = mapped_column(Text, default="")
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
seq: Mapped[int] = mapped_column(nullable=False)
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
__table_args__ = (
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
Index("ix_events_run", "thread_id", "run_id", "seq"),
)

View File

@ -0,0 +1,23 @@
"""ORM model for thread metadata."""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class ThreadMetaRow(Base):
__tablename__ = "threads_meta"
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
display_name: Mapped[str | None] = mapped_column(String(256))
status: Mapped[str] = mapped_column(String(20), default="idle")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))

View File

@ -0,0 +1,4 @@
from deerflow.persistence.repositories.run_repo import RunRepository
from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository
__all__ = ["RunRepository", "ThreadMetaRepository"]

View File

@ -0,0 +1,174 @@
"""SQLAlchemy-backed RunStore implementation.
Each method acquires and releases its own short-lived session.
Run status updates come from background workers that may live for minutes,
so we never hold a connection open across a long-running execution.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run import RunRow
from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _safe_json(obj: Any) -> Any:
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
if obj is None:
return None
if isinstance(obj, (str, int, float, bool)):
return obj
if isinstance(obj, dict):
return {k: RunRepository._safe_json(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return [RunRepository._safe_json(v) for v in obj]
if hasattr(obj, "model_dump"):
try:
return obj.model_dump()
except Exception:
pass
if hasattr(obj, "dict"):
try:
return obj.dict()
except Exception:
pass
try:
json.dumps(obj)
return obj
except (TypeError, ValueError):
return str(obj)
@staticmethod
def _row_to_dict(row: RunRow) -> dict[str, Any]:
d = row.to_dict()
# Remap JSON columns to match RunStore interface
d["metadata"] = d.pop("metadata_json", {})
d["kwargs"] = d.pop("kwargs_json", {})
# Convert datetime to ISO string for consistency with MemoryRunStore
for key in ("created_at", "updated_at"):
val = d.get(key)
if isinstance(val, datetime):
d[key] = val.isoformat()
return d
async def put(
self,
run_id,
*,
thread_id,
assistant_id=None,
owner_id=None,
status="pending",
multitask_strategy="reject",
metadata=None,
kwargs=None,
error=None,
created_at=None,
):
now = datetime.now(UTC)
row = RunRow(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=owner_id,
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
kwargs_json=self._safe_json(kwargs) or {},
error=error,
created_at=datetime.fromisoformat(created_at) if created_at else now,
updated_at=now,
)
async with self._sf() as session:
session.add(row)
await session.commit()
async def get(self, run_id):
async with self._sf() as session:
row = await session.get(RunRow, run_id)
return self._row_to_dict(row) if row else None
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
if owner_id is not None:
stmt = stmt.where(RunRow.owner_id == owner_id)
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_status(self, run_id, status, *, error=None):
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
if error is not None:
values["error"] = error
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def delete(self, run_id):
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is not None:
await session.delete(row)
await session.commit()
async def list_pending(self, *, before=None):
cutoff = datetime.fromisoformat(before) if isinstance(before, str) else (before or datetime.now(UTC))
stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= cutoff).order_by(RunRow.created_at.asc())
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
"""Update status + token usage + convenience fields on run completion."""
values: dict[str, Any] = {
"status": status,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
"updated_at": datetime.now(UTC),
}
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
if error is not None:
values["error"] = error
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()

View File

@ -0,0 +1,91 @@
"""SQLAlchemy-backed thread metadata repository."""
from __future__ import annotations
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.thread_meta import ThreadMetaRow
logger = logging.getLogger(__name__)
class ThreadMetaRepository:
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
d = row.to_dict()
d["metadata"] = d.pop("metadata_json", {})
for key in ("created_at", "updated_at"):
val = d.get(key)
if isinstance(val, datetime):
d[key] = val.isoformat()
return d
async def create(
self,
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None = None,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
now = datetime.now(UTC)
row = ThreadMetaRow(
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=owner_id,
display_name=display_name,
metadata_json=metadata or {},
created_at=now,
updated_at=now,
)
async with self._sf() as session:
session.add(row)
await session.commit()
await session.refresh(row)
return self._row_to_dict(row)
async def get(self, thread_id: str) -> dict | None:
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
return self._row_to_dict(row) if row else None
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def check_access(self, thread_id: str, owner_id: str) -> bool:
"""Check if owner_id has access to thread_id.
Returns True if: row doesn't exist (untracked thread), owner_id
is None on the row (shared thread), or owner_id matches.
"""
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return True
if row.owner_id is None:
return True
return row.owner_id == owner_id
async def update_status(self, thread_id: str, status: str) -> None:
async with self._sf() as session:
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit()
async def delete(self, thread_id: str) -> None:
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is not None:
await session.delete(row)
await session.commit()

View File

@ -1,4 +1,26 @@
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.events.store.memory import MemoryRunEventStore
__all__ = ["MemoryRunEventStore", "RunEventStore"]
def make_run_event_store(config=None) -> RunEventStore:
"""Create a RunEventStore based on run_events.backend configuration."""
if config is None or config.backend == "memory":
return MemoryRunEventStore()
if config.backend == "db":
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
# database.backend=memory but run_events.backend=db -> fallback
return MemoryRunEventStore()
from deerflow.runtime.events.store.db import DbRunEventStore
return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
if config.backend == "jsonl":
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
return JsonlRunEventStore()
raise ValueError(f"Unknown run_events backend: {config.backend!r}")
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]

View File

@ -0,0 +1,148 @@
"""SQLAlchemy-backed RunEventStore implementation.
Persists events to the ``run_events`` table. Trace content is truncated
to ``max_trace_content`` characters to avoid bloating the database.
"""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
class DbRunEventStore(RunEventStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
self._sf = session_factory
self._max_trace_content = max_trace_content
@staticmethod
def _row_to_dict(row: RunEventRow) -> dict:
d = row.to_dict()
d["metadata"] = d.pop("event_metadata", {})
val = d.get("created_at")
if isinstance(val, datetime):
d["created_at"] = val.isoformat()
d.pop("id", None)
return d
def _truncate_trace(self, category: str, content: str, metadata: dict | None) -> tuple[str, dict]:
if category == "trace" and len(content) > self._max_trace_content:
content = content[: self._max_trace_content]
metadata = {**(metadata or {}), "content_truncated": True}
return content, metadata or {}
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
content, metadata = self._truncate_trace(category, content, metadata)
async with self._sf() as session:
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
event_type=event_type,
category=category,
content=content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
)
session.add(row)
await session.commit()
await session.refresh(row)
return self._row_to_dict(row)
async def put_batch(self, events):
if not events:
return []
async with self._sf() as session:
# Get max seq for the thread (assumes all events in the batch belong to the same thread)
thread_id = events[0]["thread_id"]
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
seq = max_seq or 0
rows = []
for e in events:
seq += 1
content = e.get("content", "")
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
event_type=e["event_type"],
category=category,
content=content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
)
session.add(row)
rows.append(row)
await session.commit()
for row in rows:
await session.refresh(row)
return [self._row_to_dict(r) for r in rows]
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
stmt = stmt.where(RunEventRow.seq > after_seq)
if after_seq is not None:
# Forward pagination: first `limit` records after cursor
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
else:
# before_seq or default (latest): take last `limit` records, return ascending
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
if event_types:
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_messages_by_run(self, thread_id, run_id):
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def count_messages(self, thread_id):
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def delete_by_thread(self, thread_id):
async with self._sf() as session:
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
await session.commit()
return count
async def delete_by_run(self, thread_id, run_id):
async with self._sf() as session:
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
await session.commit()
return count

View File

@ -0,0 +1,164 @@
"""JSONL file-backed RunEventStore implementation.
Each run's events are stored in a single file:
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``
All categories (message, trace, lifecycle) are in the same file.
This backend is suitable for lightweight single-node deployments.
Known trade-off: ``list_messages()`` must scan all run files for a
thread since messages from multiple runs need unified seq ordering.
``list_events()`` reads only one file -- the fast path.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__)
class JsonlRunEventStore(RunEventStore):
def __init__(self, base_dir: str | Path | None = None):
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
def _thread_dir(self, thread_id: str) -> Path:
return self._base_dir / "threads" / thread_id / "runs"
def _run_file(self, thread_id: str, run_id: str) -> Path:
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
def _next_seq(self, thread_id: str) -> int:
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
return self._seq_counters[thread_id]
def _ensure_seq_loaded(self, thread_id: str) -> None:
"""Load max seq from existing files if not yet cached."""
if thread_id in self._seq_counters:
return
max_seq = 0
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
for line in f.read_text(encoding="utf-8").strip().splitlines():
try:
record = json.loads(line)
max_seq = max(max_seq, record.get("seq", 0))
except json.JSONDecodeError:
continue
self._seq_counters[thread_id] = max_seq
def _write_record(self, record: dict) -> None:
path = self._run_file(record["thread_id"], record["run_id"])
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, "a", encoding="utf-8") as f:
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
def _read_thread_events(self, thread_id: str) -> list[dict]:
"""Read all events for a thread, sorted by seq."""
events = []
thread_dir = self._thread_dir(thread_id)
if not thread_dir.exists():
return events
for f in sorted(thread_dir.glob("*.jsonl")):
for line in f.read_text(encoding="utf-8").strip().splitlines():
if not line:
continue
try:
events.append(json.loads(line))
except json.JSONDecodeError:
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
"""Read events for a specific run file."""
path = self._run_file(thread_id, run_id)
if not path.exists():
return []
events = []
for line in path.read_text(encoding="utf-8").strip().splitlines():
if not line:
continue
try:
events.append(json.loads(line))
except json.JSONDecodeError:
continue
events.sort(key=lambda e: e.get("seq", 0))
return events
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
self._ensure_seq_loaded(thread_id)
seq = self._next_seq(thread_id)
record = {
"thread_id": thread_id,
"run_id": run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"seq": seq,
"created_at": created_at or datetime.now(UTC).isoformat(),
}
self._write_record(record)
return record
async def put_batch(self, events):
if not events:
return []
results = []
for ev in events:
record = await self.put(**ev)
results.append(record)
return results
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
all_events = self._read_thread_events(thread_id)
messages = [e for e in all_events if e.get("category") == "message"]
if before_seq is not None:
messages = [e for e in messages if e["seq"] < before_seq]
return messages[-limit:]
elif after_seq is not None:
messages = [e for e in messages if e["seq"] > after_seq]
return messages[:limit]
else:
return messages[-limit:]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
events = self._read_run_events(thread_id, run_id)
if event_types is not None:
events = [e for e in events if e.get("event_type") in event_types]
return events[:limit]
async def list_messages_by_run(self, thread_id, run_id):
events = self._read_run_events(thread_id, run_id)
return [e for e in events if e.get("category") == "message"]
async def count_messages(self, thread_id):
all_events = self._read_thread_events(thread_id)
return sum(1 for e in all_events if e.get("category") == "message")
async def delete_by_thread(self, thread_id):
all_events = self._read_thread_events(thread_id)
count = len(all_events)
thread_dir = self._thread_dir(thread_id)
if thread_dir.exists():
for f in thread_dir.glob("*.jsonl"):
f.unlink()
self._seq_counters.pop(thread_id, None)
return count
async def delete_by_run(self, thread_id, run_id):
events = self._read_run_events(thread_id, run_id)
count = len(events)
path = self._run_file(thread_id, run_id)
if path.exists():
path.unlink()
return count

View File

@ -0,0 +1,333 @@
"""Run event capture via LangChain callbacks.
RunJournal sits between LangChain's callback mechanism and the pluggable
RunEventStore. It standardizes callback data into RunEvent records and
handles token usage accumulation.
Key design decisions:
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
- All LangChain objects serialized via serialize_lc_object (same as worker.py SSE)
- Token usage accumulated in memory, written to RunRow on run completion
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
"""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler
if TYPE_CHECKING:
from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__)
class RunJournal(BaseCallbackHandler):
"""LangChain callback handler that captures events to RunEventStore."""
def __init__(
self,
run_id: str,
thread_id: str,
event_store: RunEventStore,
*,
track_token_usage: bool = True,
on_complete: Callable[..., Any] | None = None,
flush_threshold: int = 20,
):
super().__init__()
self.run_id = run_id
self.thread_id = thread_id
self._store = event_store
self._track_tokens = track_token_usage
self._on_complete = on_complete
self._flush_threshold = flush_threshold
# Write buffer
self._buffer: list[dict] = []
# Token accumulators
self._total_input_tokens = 0
self._total_output_tokens = 0
self._total_tokens = 0
self._llm_call_count = 0
self._lead_agent_tokens = 0
self._subagent_tokens = 0
self._middleware_tokens = 0
# Convenience fields
self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None
self._msg_count = 0
# Latency tracking
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
# -- Lifecycle callbacks --
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
# Only record for the top-level chain (parent_run_id is None)
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run_start",
category="lifecycle",
metadata={"input_preview": str(inputs)[:500]},
)
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync()
if self._on_complete:
self._on_complete(
total_input_tokens=self._total_input_tokens,
total_output_tokens=self._total_output_tokens,
total_tokens=self._total_tokens,
llm_call_count=self._llm_call_count,
lead_agent_tokens=self._lead_agent_tokens,
subagent_tokens=self._subagent_tokens,
middleware_tokens=self._middleware_tokens,
message_count=self._msg_count,
last_ai_message=self._last_ai_msg,
first_human_message=self._first_human_msg,
)
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
return
self._put(
event_type="run_error",
category="lifecycle",
content=str(error),
metadata={"error_type": type(error).__name__},
)
self._flush_sync()
# -- LLM callbacks --
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times[str(run_id)] = time.monotonic()
self._put(
event_type="llm_start",
category="trace",
metadata={"model_name": serialized.get("name", "")},
)
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.serialization import serialize_lc_object
try:
message = response.generations[0][0].message
except (IndexError, AttributeError):
logger.debug("on_llm_end: could not extract message from response")
return
serialized_msg = serialize_lc_object(message)
caller = self._identify_caller(kwargs)
# Latency
start = self._llm_start_times.pop(str(run_id), None)
latency_ms = int((time.monotonic() - start) * 1000) if start else None
# Token usage from message
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# trace event: llm_end (every LLM call)
raw_content = getattr(message, "content", "")
self._put(
event_type="llm_end",
category="trace",
content=raw_content if isinstance(raw_content, str) else str(raw_content),
metadata={
"message": serialized_msg,
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
},
)
# message event: ai_message (only lead_agent final replies with content)
if caller == "lead_agent":
content = getattr(message, "content", "")
if isinstance(content, str) and content:
tool_calls = getattr(message, "tool_calls", None) or []
tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
self._put(
event_type="ai_message",
category="message",
content=content,
metadata={
"model_name": model_name,
"tool_calls": tool_calls_summary,
},
)
self._last_ai_msg = content[:2000]
self._msg_count += 1
# Token accumulation
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if self._track_tokens and total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm_error", category="trace", content=str(error))
# -- Tool callbacks --
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
self._put(
event_type="tool_start",
category="trace",
metadata={
"tool_name": serialized.get("name", ""),
"tool_call_id": kwargs.get("tool_call_id"),
"args": str(input_str)[:2000],
},
)
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
self._put(
event_type="tool_end",
category="trace",
content=str(output),
metadata={
"tool_name": kwargs.get("name", ""),
"tool_call_id": kwargs.get("tool_call_id"),
"status": "success",
},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": kwargs.get("name", ""),
"tool_call_id": kwargs.get("tool_call_id"),
},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.serialization import serialize_lc_object
if name == "summarization":
data_dict = data if isinstance(data, dict) else {}
self._put(
event_type="summarization",
category="trace",
content=data_dict.get("summary", ""),
metadata={
"replaced_message_ids": data_dict.get("replaced_message_ids", []),
"replaced_count": data_dict.get("replaced_count", 0),
},
)
self._put(
event_type="summary",
category="message",
content=data_dict.get("summary", ""),
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
else:
event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
self._put(
event_type=name,
category="trace",
metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
)
# -- Internal methods --
def _put(self, *, event_type: str, category: str, content: str = "", metadata: dict | None = None) -> None:
self._buffer.append({
"thread_id": self.thread_id,
"run_id": self.run_id,
"event_type": event_type,
"category": category,
"content": content,
"metadata": metadata or {},
"created_at": datetime.now(UTC).isoformat(),
})
if len(self._buffer) >= self._flush_threshold:
self._flush_sync()
def _flush_sync(self) -> None:
"""Flush buffer to RunEventStore.
BaseCallbackHandler methods are synchronous. We schedule the async
put_batch via the current event loop.
"""
if not self._buffer:
return
batch = self._buffer.copy()
self._buffer.clear()
try:
loop = asyncio.get_running_loop()
loop.create_task(self._flush_async(batch))
except RuntimeError:
logger.warning("RunJournal: no event loop, dropping %d events", len(batch))
async def _flush_async(self, batch: list[dict]) -> None:
try:
await self._store.put_batch(batch)
except Exception:
logger.warning("RunJournal: failed to flush %d events", len(batch), exc_info=True)
def _identify_caller(self, kwargs: dict) -> str:
for tag in kwargs.get("tags") or []:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag
return "unknown"
# -- Public methods (called by worker) --
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
async def flush(self) -> None:
"""Force flush. Used in cancel/error paths."""
if self._buffer:
batch = self._buffer.copy()
self._buffer.clear()
await self._store.put_batch(batch)
def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion."""
return {
"total_input_tokens": self._total_input_tokens,
"total_output_tokens": self._total_output_tokens,
"total_tokens": self._total_tokens,
"llm_call_count": self._llm_call_count,
"lead_agent_tokens": self._lead_agent_tokens,
"subagent_tokens": self._subagent_tokens,
"middleware_tokens": self._middleware_tokens,
"message_count": self._msg_count,
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,
}
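A condensed wiring sketch under the same assumptions as the worker changes later in this commit (the agent invocation itself is elided; run_repo is assumed to be a RunRepository, one possible consumer of the accumulated data):

    journal = RunJournal(run_id, thread_id, event_store, track_token_usage=True)
    config.setdefault("callbacks", []).append(journal)
    # ... stream the agent with this config ...
    await journal.flush()
    await run_repo.update_run_completion(run_id, status="success", **journal.get_completion_data())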

View File

@ -1,4 +1,4 @@
"""In-memory run registry."""
"""In-memory run registry with optional persistent RunStore backing."""
from __future__ import annotations
@ -7,9 +7,13 @@ import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING
from .schemas import DisconnectMode, RunStatus
if TYPE_CHECKING:
from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
@ -38,11 +42,17 @@ class RunRecord:
class RunManager:
"""In-memory run registry. All mutations are protected by an asyncio lock."""
"""In-memory run registry with optional persistent RunStore backing.
def __init__(self) -> None:
All mutations are protected by an asyncio lock. When a ``store`` is
provided, serializable metadata is also persisted to the store so
that run history survives process restarts.
"""
def __init__(self, store: RunStore | None = None) -> None:
self._runs: dict[str, RunRecord] = {}
self._lock = asyncio.Lock()
self._store = store
async def create(
self,
@ -71,6 +81,20 @@ class RunManager:
)
async with self._lock:
self._runs[run_id] = record
if self._store is not None:
try:
await self._store.put(
run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending.value,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
)
except Exception:
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record
@ -96,6 +120,11 @@ class RunManager:
record.updated_at = _now_iso()
if error is not None:
record.error = error
if self._store is not None:
try:
await self._store.update_status(run_id, status.value, error=error)
except Exception:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.info("Run %s -> %s", run_id, status.value)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
@ -185,6 +214,21 @@ class RunManager:
)
self._runs[run_id] = record
if self._store is not None:
try:
await self._store.put(
run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending.value,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=now,
)
except Exception:
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record

View File

@ -45,6 +45,8 @@ async def run_agent(
stream_subgraphs: bool = False,
interrupt_before: list[str] | Literal["*"] | None = None,
interrupt_after: list[str] | Literal["*"] | None = None,
event_store: Any | None = None,
run_events_config: Any | None = None,
) -> None:
"""Execute an agent in the background, publishing events to *bridge*."""
@ -52,6 +54,30 @@ async def run_agent(
thread_id = record.thread_id
requested_modes: set[str] = set(stream_modes or ["values"])
# Initialize RunJournal for event capture
journal = None
if event_store is not None:
from deerflow.runtime.journal import RunJournal
journal = RunJournal(
run_id=run_id,
thread_id=thread_id,
event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True),
)
# Write human_message event
user_input = _extract_user_input(graph_input)
if user_input:
await event_store.put(
thread_id=thread_id,
run_id=run_id,
event_type="human_message",
category="message",
content=user_input,
)
journal.set_first_human_message(user_input)
# Track whether "events" was requested but skipped
if "events" in requested_modes:
logger.info(
@ -92,6 +118,10 @@ async def run_agent(
runtime = Runtime(context={"thread_id": thread_id}, store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a callback
if journal is not None:
config.setdefault("callbacks", []).append(journal)
runnable_config = RunnableConfig(**config)
agent = agent_factory(config=runnable_config)
@ -206,6 +236,13 @@ async def run_agent(
)
finally:
# Flush any buffered journal events
if journal is not None:
try:
await journal.flush()
except Exception:
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60))
@ -227,6 +264,23 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode
def _extract_user_input(graph_input: dict) -> str:
"""Extract user input text from graph_input for event recording."""
messages = graph_input.get("messages")
if not messages:
return ""
# Take the last message (usually the user's input)
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, str):
return last
if hasattr(last, "content"):
content = last.content
return content if isinstance(content, str) else str(content)
if isinstance(last, dict):
return str(last.get("content", ""))
return ""
def _unpack_stream_item(
item: Any,
lg_modes: list[str],

View File

@ -0,0 +1,154 @@
"""Phase 2-B integration tests.
End-to-end test: simulate a run's complete lifecycle, verify data
is correctly written to both RunStore and RunEventStore.
"""
import asyncio
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
from deerflow.runtime.journal import RunJournal
from deerflow.runtime.runs.store.memory import MemoryRunStore
def _make_llm_response(content="Hello", usage=None):
msg = MagicMock()
msg.content = content
msg.tool_calls = []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
return response
class TestRunLifecycle:
@pytest.mark.anyio
async def test_full_run_lifecycle(self):
"""Simulate a complete run lifecycle with RunStore + RunEventStore."""
run_store = MemoryRunStore()
event_store = MemoryRunEventStore()
# 1. Create run
await run_store.put("r1", thread_id="t1", status="pending")
# 2. Write human_message
await event_store.put(
thread_id="t1",
run_id="r1",
event_type="human_message",
category="message",
content="What is AI?",
)
# 3. Simulate RunJournal callback sequence
on_complete_data = {}
def on_complete(**data):
on_complete_data.update(data)
journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100)
journal.set_first_human_message("What is AI?")
# chain_start (top-level)
journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None)
# llm_start + llm_end
llm_run_id = uuid4()
journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"])
usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}
journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"])
# chain_end (triggers on_complete + flush_sync which creates a task)
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await journal.flush()
# Let event loop process any pending flush tasks from _flush_sync
await asyncio.sleep(0.05)
# 4. Verify messages
messages = await event_store.list_messages("t1")
assert len(messages) == 2 # human + ai
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
assert messages[1]["content"] == "AI is artificial intelligence."
# 5. Verify events
events = await event_store.list_events("t1", "r1")
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
# 6. Verify on_complete data
assert on_complete_data["total_tokens"] == 150
assert on_complete_data["llm_call_count"] == 1
assert on_complete_data["lead_agent_tokens"] == 150
assert on_complete_data["message_count"] == 1
assert on_complete_data["last_ai_message"] == "AI is artificial intelligence."
assert on_complete_data["first_human_message"] == "What is AI?"
@pytest.mark.anyio
async def test_run_with_tool_calls(self):
"""Simulate a run that uses tools."""
event_store = MemoryRunEventStore()
journal = RunJournal("r1", "t1", event_store, flush_threshold=100)
# tool_start + tool_end
journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4())
journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search")
await journal.flush()
events = await event_store.list_events("t1", "r1")
assert len(events) == 2
assert events[0]["event_type"] == "tool_start"
assert events[1]["event_type"] == "tool_end"
@pytest.mark.anyio
async def test_multi_run_thread(self):
"""Multiple runs on the same thread maintain unified seq ordering."""
event_store = MemoryRunEventStore()
# Run 1
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
# Run 2
await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")
messages = await event_store.list_messages("t1")
assert len(messages) == 4
assert [m["seq"] for m in messages] == [1, 2, 3, 4]
assert messages[0]["run_id"] == "r1"
assert messages[2]["run_id"] == "r2"
@pytest.mark.anyio
async def test_runmanager_with_store_backing(self):
"""RunManager persists to RunStore when one is provided."""
from deerflow.runtime.runs.manager import RunManager
run_store = MemoryRunStore()
mgr = RunManager(store=run_store)
record = await mgr.create("t1", assistant_id="lead_agent")
# Verify persisted to store
row = await run_store.get(record.run_id)
assert row is not None
assert row["thread_id"] == "t1"
assert row["status"] == "pending"
# Status update
from deerflow.runtime.runs.schemas import RunStatus
await mgr.set_status(record.run_id, RunStatus.running)
row = await run_store.get(record.run_id)
assert row["status"] == "running"

View File

@ -1,14 +1,7 @@
"""Tests for RunEventStore ABC + MemoryRunEventStore.
"""Tests for RunEventStore contract across all backends.
Covers:
- Basic write and query (put, seq assignment, cross-thread independence)
- list_messages (category filtering, pagination, cross-run ordering)
- list_events (run filtering, event_types filtering)
- list_messages_by_run
- count_messages
- put_batch
- delete_by_thread, delete_by_run
- Edge cases (empty thread/run)
Shared contract tests run against a ``store`` fixture (memory-backed);
the DB and JSONL tests construct their stores inside each test.
"""
import pytest
@ -35,7 +28,6 @@ class TestPutAndSeq:
assert record["event_type"] == "human_message"
assert record["category"] == "message"
assert record["content"] == "hello"
assert record["metadata"] == {}
assert "created_at" in record
@pytest.mark.anyio
@ -91,7 +83,6 @@ class TestListMessages:
@pytest.mark.anyio
async def test_before_seq_pagination(self, store):
# Put 10 messages with seq 1..10
for i in range(10):
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
messages = await store.list_messages("t1", before_seq=6, limit=3)
@ -236,7 +227,6 @@ class TestDelete:
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
count = await store.delete_by_run("t1", "r2")
assert count == 2
# r1 events should still be there
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["run_id"] == "r1"
@ -270,3 +260,145 @@ class TestEdgeCases:
@pytest.mark.anyio
async def test_empty_thread_count_messages(self, store):
assert await store.count_messages("empty") == 0
# -- DB-specific tests --
class TestDbRunEventStore:
"""Tests for DbRunEventStore with temp SQLite."""
@pytest.mark.anyio
async def test_basic_crud(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
assert r["seq"] == 1
r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello")
assert r2["seq"] == 2
messages = await s.list_messages("t1")
assert len(messages) == 2
count = await s.count_messages("t1")
assert count == 2
await close_engine()
@pytest.mark.anyio
async def test_trace_content_truncation(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory(), max_trace_content=100)
long = "x" * 200
r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long)
assert len(r["content"]) == 100
assert r["metadata"].get("content_truncated") is True
# message content NOT truncated
m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long)
assert len(m["content"]) == 200
await close_engine()
@pytest.mark.anyio
async def test_pagination(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
for i in range(10):
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
# before_seq
msgs = await s.list_messages("t1", before_seq=6, limit=3)
assert [m["seq"] for m in msgs] == [3, 4, 5]
# after_seq
msgs = await s.list_messages("t1", after_seq=7, limit=3)
assert [m["seq"] for m in msgs] == [8, 9, 10]
# default (latest)
msgs = await s.list_messages("t1", limit=3)
assert [m["seq"] for m in msgs] == [8, 9, 10]
await close_engine()
@pytest.mark.anyio
async def test_delete(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.runtime.events.store.db import DbRunEventStore
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
s = DbRunEventStore(get_session_factory())
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
c = await s.delete_by_run("t1", "r2")
assert c == 1
assert await s.count_messages("t1") == 1
c = await s.delete_by_thread("t1")
assert c == 1
assert await s.count_messages("t1") == 0
await close_engine()
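# The four tests above repeat the same engine setup/teardown. A shared async
# fixture like this sketch could factor it out; the name `db_event_store` is
# hypothetical (this module does not define or use it) and it assumes the
# anyio pytest plugin handles async generator fixtures here.
@pytest.fixture
async def db_event_store(tmp_path):
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
    from deerflow.runtime.events.store.db import DbRunEventStore

    await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'fixture.db'}", sqlite_dir=str(tmp_path))
    yield DbRunEventStore(get_session_factory())
    await close_engine()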
# -- JSONL-specific tests --
class TestJsonlRunEventStore:
@pytest.mark.anyio
async def test_basic_crud(self, tmp_path):
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
assert r["seq"] == 1
messages = await s.list_messages("t1")
assert len(messages) == 1
@pytest.mark.anyio
async def test_file_at_correct_path(self, tmp_path):
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
assert (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r1.jsonl").exists()
@pytest.mark.anyio
async def test_cross_run_messages(self, tmp_path):
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
messages = await s.list_messages("t1")
assert len(messages) == 2
assert [m["seq"] for m in messages] == [1, 2]
@pytest.mark.anyio
async def test_delete_by_run(self, tmp_path):
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
c = await s.delete_by_run("t1", "r2")
assert c == 1
assert not (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r2.jsonl").exists()
assert await s.count_messages("t1") == 1
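# Sketch of the per-run layout and seq recovery these tests imply: one JSONL
# file per run under base_dir/threads/<thread_id>/runs/<run_id>.jsonl, with the
# next seq derived from records already on disk. The helper below is purely
# illustrative; the store's actual recovery logic may differ.
from pathlib import Path

def _sketch_next_seq(base_dir: Path, thread_id: str) -> int:
    runs_dir = base_dir / "threads" / thread_id / "runs"
    if not runs_dir.exists():
        return 1
    on_disk = sum(
        sum(1 for line in f.read_text().splitlines() if line.strip())
        for f in runs_dir.glob("*.jsonl")
    )
    return on_disk + 1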


@ -0,0 +1,230 @@
"""Tests for RunJournal callback handler.
Uses MemoryRunEventStore as the backend for direct event inspection.
"""
import asyncio
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
from deerflow.runtime.journal import RunJournal
@pytest.fixture
def journal_setup():
store = MemoryRunEventStore()
on_complete_data = {}
def on_complete(**data):
on_complete_data.update(data)
j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100)
return j, store, on_complete_data
def _make_llm_response(content="Hello", usage=None):
"""Create a mock LLM response with a message."""
msg = MagicMock()
msg.content = content
msg.tool_calls = []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
return response
class TestLlmCallbacks:
@pytest.mark.anyio
async def test_on_llm_end_produces_trace_event(self, journal_setup):
j, store, _ = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
await j.flush()
events = await store.list_events("t1", "r1")
trace_events = [e for e in events if e["event_type"] == "llm_end"]
assert len(trace_events) == 1
assert trace_events[0]["category"] == "trace"
@pytest.mark.anyio
async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
j, store, _ = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"])
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "ai_message"
assert messages[0]["content"] == "Answer"
@pytest.mark.anyio
async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
j, store, _ = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"])
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
@pytest.mark.anyio
async def test_token_accumulation(self, journal_setup):
j, store, on_complete_data = journal_setup
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"])
assert j._total_input_tokens == 30
assert j._total_output_tokens == 15
assert j._total_tokens == 45
assert j._llm_call_count == 2
@pytest.mark.anyio
async def test_caller_token_classification(self, journal_setup):
j, store, _ = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"])
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"])
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 15
assert j._middleware_tokens == 15
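# Sketch of the per-caller token bucketing asserted above. The counter names
# (_lead_agent_tokens, _subagent_tokens, _middleware_tokens) come from the
# assertions; the accumulation logic itself is an assumption.
def _sketch_accumulate(journal, caller: str, usage: dict) -> None:
    total = usage.get("total_tokens", 0)
    if caller == "lead_agent":
        journal._lead_agent_tokens += total
    elif caller.startswith("subagent:"):
        journal._subagent_tokens += total
    elif caller.startswith("middleware:"):
        journal._middleware_tokens += total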
@pytest.mark.anyio
async def test_usage_metadata_none_no_crash(self, journal_setup):
j, store, _ = journal_setup
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"])
# Should not raise
await j.flush()
@pytest.mark.anyio
async def test_latency_tracking(self, journal_setup):
j, store, _ = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
await j.flush()
events = await store.list_events("t1", "r1")
llm_end = [e for e in events if e["event_type"] == "llm_end"][0]
assert "latency_ms" in llm_end["metadata"]
assert llm_end["metadata"]["latency_ms"] is not None
class TestLifecycleCallbacks:
@pytest.mark.anyio
async def test_on_chain_end_triggers_on_complete(self, journal_setup):
j, store, on_complete_data = journal_setup
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert "total_tokens" in on_complete_data
assert "message_count" in on_complete_data
@pytest.mark.anyio
async def test_nested_chain_ignored(self, journal_setup):
j, store, on_complete_data = journal_setup
parent_id = uuid4()
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
await j.flush()
events = await store.list_events("t1", "r1")
lifecycle = [e for e in events if e["category"] == "lifecycle"]
assert len(lifecycle) == 0
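# Sketch of the root-chain guard these two tests rely on (assumed shape, not
# RunJournal's actual code): only the outermost chain, identified by
# parent_run_id being None, finalizes the run and fires on_complete.
def _sketch_on_chain_end(parent_run_id, on_complete, completion_data: dict) -> None:
    if parent_run_id is not None:
        return  # nested chain inside the graph, not the run boundary
    on_complete(**completion_data)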
class TestToolCallbacks:
@pytest.mark.anyio
async def test_tool_start_end_produce_trace(self, journal_setup):
j, store, _ = journal_setup
j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4())
j.on_tool_end("results", run_id=uuid4(), name="web_search")
await j.flush()
events = await store.list_events("t1", "r1")
types = {e["event_type"] for e in events}
assert "tool_start" in types
assert "tool_end" in types
class TestCustomEvents:
@pytest.mark.anyio
async def test_summarization_event(self, journal_setup):
j, store, _ = journal_setup
j.on_custom_event(
"summarization",
{"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]},
run_id=uuid4(),
)
await j.flush()
events = await store.list_events("t1", "r1")
trace = [e for e in events if e["event_type"] == "summarization"]
assert len(trace) == 1
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "summary"
class TestBufferFlush:
@pytest.mark.anyio
async def test_flush_threshold(self, journal_setup):
j, store, _ = journal_setup
j._flush_threshold = 3
j.on_tool_start({"name": "a"}, "x", run_id=uuid4())
j.on_tool_start({"name": "b"}, "x", run_id=uuid4())
# Buffer has 2 events, not yet flushed
assert len(j._buffer) == 2
j.on_tool_start({"name": "c"}, "x", run_id=uuid4())
# Buffer should have been flushed (threshold=3 triggers flush)
# Give the async task a chance to complete
await asyncio.sleep(0.1)
events = await store.list_events("t1", "r1")
assert len(events) >= 3
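# Sketch of the buffer/flush pattern exercised above (an assumption about the
# internals): synchronous callbacks append to _buffer and, once the threshold
# is hit, schedule an async flush so the callback itself never blocks.
def _sketch_record(journal, event: dict) -> None:
    journal._buffer.append(event)
    if len(journal._buffer) >= journal._flush_threshold:
        asyncio.create_task(journal.flush())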
class TestIdentifyCaller:
def test_lead_agent_tag(self, journal_setup):
j, _, _ = journal_setup
assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent"
def test_subagent_tag(self, journal_setup):
j, _, _ = journal_setup
assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research"
def test_middleware_tag(self, journal_setup):
j, _, _ = journal_setup
assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization"
def test_no_tags_returns_unknown(self, journal_setup):
j, _, _ = journal_setup
assert j._identify_caller({"tags": []}) == "unknown"
assert j._identify_caller({}) == "unknown"
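# A minimal sketch of the tag-to-caller mapping these tests pin down; the real
# _identify_caller in deerflow.runtime.journal may differ in detail.
def _sketch_identify_caller(kwargs: dict) -> str:
    for tag in kwargs.get("tags") or []:
        if tag == "lead_agent" or tag.startswith(("subagent:", "middleware:")):
            return tag
    return "unknown"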
class TestPublicMethods:
@pytest.mark.anyio
async def test_set_first_human_message(self, journal_setup):
j, _, _ = journal_setup
j.set_first_human_message("Hello world")
assert j._first_human_msg == "Hello world"
@pytest.mark.anyio
async def test_get_completion_data(self, journal_setup):
j, _, _ = journal_setup
j._total_tokens = 100
j._msg_count = 5
data = j.get_completion_data()
assert data["total_tokens"] == 100
assert data["message_count"] == 5


@ -0,0 +1,155 @@
"""Tests for RunRepository (SQLAlchemy-backed RunStore).
Uses a temp SQLite DB to test ORM-backed CRUD operations.
"""
import pytest
from deerflow.persistence.repositories.run_repo import RunRepository
async def _make_repo(tmp_path):
from deerflow.persistence.engine import get_session_factory, init_engine
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
return RunRepository(get_session_factory())
async def _cleanup():
from deerflow.persistence.engine import close_engine
await close_engine()
class TestRunRepository:
@pytest.mark.anyio
async def test_put_and_get(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="pending")
row = await repo.get("r1")
assert row is not None
assert row["run_id"] == "r1"
assert row["thread_id"] == "t1"
assert row["status"] == "pending"
await _cleanup()
@pytest.mark.anyio
async def test_get_missing_returns_none(self, tmp_path):
repo = await _make_repo(tmp_path)
assert await repo.get("nope") is None
await _cleanup()
@pytest.mark.anyio
async def test_update_status(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1")
await repo.update_status("r1", "running")
row = await repo.get("r1")
assert row["status"] == "running"
await _cleanup()
@pytest.mark.anyio
async def test_update_status_with_error(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1")
await repo.update_status("r1", "error", error="boom")
row = await repo.get("r1")
assert row["status"] == "error"
assert row["error"] == "boom"
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1")
await repo.put("r2", thread_id="t1")
await repo.put("r3", thread_id="t2")
rows = await repo.list_by_thread("t1")
assert len(rows) == 2
assert all(r["thread_id"] == "t1" for r in rows)
await _cleanup()
@pytest.mark.anyio
async def test_list_by_thread_owner_filter(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", owner_id="alice")
await repo.put("r2", thread_id="t1", owner_id="bob")
rows = await repo.list_by_thread("t1", owner_id="alice")
assert len(rows) == 1
assert rows[0]["owner_id"] == "alice"
await _cleanup()
@pytest.mark.anyio
async def test_delete(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1")
await repo.delete("r1")
assert await repo.get("r1") is None
await _cleanup()
@pytest.mark.anyio
async def test_delete_nonexistent_is_noop(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.delete("nope") # should not raise
await _cleanup()
@pytest.mark.anyio
async def test_list_pending(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="pending")
await repo.put("r2", thread_id="t1", status="running")
await repo.put("r3", thread_id="t2", status="pending")
pending = await repo.list_pending()
assert len(pending) == 2
assert all(r["status"] == "pending" for r in pending)
await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion(
"r1",
status="success",
total_input_tokens=100,
total_output_tokens=50,
total_tokens=150,
llm_call_count=2,
lead_agent_tokens=120,
subagent_tokens=20,
middleware_tokens=10,
message_count=3,
last_ai_message="The answer is 42",
first_human_message="What is the meaning?",
)
row = await repo.get("r1")
assert row["status"] == "success"
assert row["total_tokens"] == 150
assert row["llm_call_count"] == 2
assert row["lead_agent_tokens"] == 120
assert row["message_count"] == 3
assert row["last_ai_message"] == "The answer is 42"
assert row["first_human_message"] == "What is the meaning?"
await _cleanup()
@pytest.mark.anyio
async def test_metadata_preserved(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", metadata={"key": "value"})
row = await repo.get("r1")
assert row["metadata"] == {"key": "value"}
await _cleanup()
@pytest.mark.anyio
async def test_kwargs_with_non_serializable(self, tmp_path):
"""kwargs containing non-JSON-serializable objects should be safely handled."""
repo = await _make_repo(tmp_path)
class Dummy:
pass
await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()})
row = await repo.get("r1")
assert "obj" in row["kwargs"]
await _cleanup()
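# The last test relies on non-JSON-serializable kwargs being stored without
# raising. One common approach (shown only as a sketch; an assumption about
# RunRepository, not its confirmed behaviour) is a repr fallback when encoding
# the payload:
import json

def _sketch_safe_json(payload: dict) -> str:
    return json.dumps(payload, default=repr)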