mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-29 13:28:11 +00:00
feat(persistence): add ORM models, repositories, DB/JSONL event stores, RunJournal, and API endpoints
Phase 2-B: run persistence + event storage + token tracking. - ORM models: RunRow (with token fields), ThreadMetaRow, RunEventRow - RunRepository implements RunStore ABC via SQLAlchemy ORM - ThreadMetaRepository with owner access control - DbRunEventStore with trace content truncation and cursor pagination - JsonlRunEventStore with per-run files and seq recovery from disk - RunJournal (BaseCallbackHandler) captures LLM/tool/lifecycle events, accumulates token usage by caller type, buffers and flushes to store - RunManager now accepts optional RunStore for persistent backing - Worker creates RunJournal, writes human_message, injects callbacks - Gateway deps use factory functions (RunRepository when DB available) - New endpoints: messages, run messages, run events, token-usage - ThreadCreateRequest gains assistant_id field - 92 tests pass (33 new), zero regressions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
23eacf9533
commit
e3179cd54d
@ -31,27 +31,23 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.persistence.engine import close_engine, init_engine_from_config
|
||||
from deerflow.runtime import make_store, make_stream_bridge
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
async with AsyncExitStack() as stack:
|
||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
||||
app.state.store = await stack.enter_async_context(make_store())
|
||||
app.state.run_manager = RunManager()
|
||||
|
||||
# Initialize persistence layer from unified database config
|
||||
config = get_app_config()
|
||||
await init_engine_from_config(config.database)
|
||||
|
||||
# Initialize run store (MemoryRunStore for now; switch to ORM-backed
|
||||
# RunRepository when models are implemented)
|
||||
app.state.run_store = MemoryRunStore()
|
||||
# Initialize run store (RunRepository if DB available, else MemoryRunStore)
|
||||
app.state.run_store = _make_run_store()
|
||||
|
||||
# Initialize run event store (MemoryRunEventStore for now)
|
||||
# TODO(Phase 2-B): switch to db/jsonl backend based on config.run_events.backend
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
# Initialize run event store based on config
|
||||
app.state.run_event_store = _make_run_event_store(config)
|
||||
|
||||
app.state.run_event_store = MemoryRunEventStore()
|
||||
# RunManager with store backing for persistence
|
||||
app.state.run_manager = RunManager(store=app.state.run_store)
|
||||
|
||||
try:
|
||||
yield
|
||||
@ -59,6 +55,32 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
await close_engine()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_run_store() -> RunStore:
    """Build the RunStore backend.

    Prefers the ORM-backed ``RunRepository`` when a database session factory
    has been initialized; otherwise falls back to the in-memory store.
    """
    from deerflow.persistence.engine import get_session_factory

    session_factory = get_session_factory()
    if session_factory is None:
        # No database engine configured -> volatile, process-local storage.
        from deerflow.runtime.runs.store.memory import MemoryRunStore

        return MemoryRunStore()

    from deerflow.persistence.repositories.run_repo import RunRepository

    return RunRepository(session_factory)
|
||||
|
||||
|
||||
def _make_run_event_store(config) -> RunEventStore:
    """Build the RunEventStore selected by ``config.run_events`` (if present).

    A config object without a ``run_events`` section selects the factory's
    default (in-memory) backend.
    """
    from deerflow.runtime.events.store import make_run_event_store

    return make_run_event_store(getattr(config, "run_events", None))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Getters -- called by routers per-request
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -19,7 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
||||
from app.gateway.services import sse_consumer, start_run
|
||||
from deerflow.runtime import RunRecord, serialize_channel_values
|
||||
|
||||
@ -263,3 +263,77 @@ async def stream_existing_run(
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Messages / Events / Token usage endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages")
async def list_thread_messages(
    thread_id: str,
    request: Request,
    limit: int = Query(default=50, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """Return displayable messages for a thread (across all runs).

    ``before_seq``/``after_seq`` are cursor parameters over the per-thread
    event sequence; the store interprets them for pagination.
    """
    store = get_run_event_store(request)
    return await store.list_messages(
        thread_id,
        limit=limit,
        before_seq=before_seq,
        after_seq=after_seq,
    )
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/messages")
async def list_run_messages(thread_id: str, run_id: str, request: Request) -> list[dict]:
    """Return displayable messages for a specific run."""
    store = get_run_event_store(request)
    messages = await store.list_messages_by_run(thread_id, run_id)
    return messages
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/events")
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    limit: int = Query(default=500, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit).

    ``event_types`` is a comma-separated filter (e.g. ``"message,trace"``);
    omitting it returns every event type.
    """
    store = get_run_event_store(request)
    type_filter = event_types.split(",") if event_types else None
    return await store.list_events(thread_id, run_id, event_types=type_filter, limit=limit)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/token-usage")
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Thread-level token usage aggregation.

    Token counters are written when a run finishes, so only runs whose
    status is "success" or "error" contribute.
    NOTE(review): "timeout"/"interrupted" runs are excluded from the totals
    even though they may have consumed tokens -- confirm this is intended.
    """
    run_store = get_run_store(request)
    runs = await run_store.list_by_thread(thread_id, limit=10000)
    completed = [r for r in runs if r.get("status") in ("success", "error")]

    # Single pass over completed runs accumulates every aggregate.
    totals = {"total_tokens": 0, "total_input_tokens": 0, "total_output_tokens": 0}
    by_model: dict[str, dict] = {}
    by_caller = {"lead_agent": 0, "subagent": 0, "middleware": 0}
    for run in completed:
        for key in totals:
            totals[key] += run.get(key, 0)
        model = run.get("model_name") or "unknown"
        model_entry = by_model.setdefault(model, {"tokens": 0, "runs": 0})
        model_entry["tokens"] += run.get("total_tokens", 0)
        model_entry["runs"] += 1
        for caller in by_caller:
            by_caller[caller] += run.get(f"{caller}_tokens", 0)

    return {
        "thread_id": thread_id,
        "total_tokens": totals["total_tokens"],
        "total_input_tokens": totals["total_input_tokens"],
        "total_output_tokens": totals["total_output_tokens"],
        "total_runs": len(completed),
        "by_model": by_model,
        "by_caller": by_caller,
    }
|
||||
|
||||
@ -63,6 +63,7 @@ class ThreadCreateRequest(BaseModel):
|
||||
"""Request body for creating a thread."""
|
||||
|
||||
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
|
||||
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
|
||||
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
|
||||
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from typing import Any
|
||||
from fastapi import HTTPException, Request
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
|
||||
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_store, get_stream_bridge
|
||||
from deerflow.runtime import (
|
||||
END_SENTINEL,
|
||||
HEARTBEAT_SENTINEL,
|
||||
@ -245,6 +245,12 @@ async def start_run(
|
||||
run_mgr = get_run_manager(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
store = get_store(request)
|
||||
event_store = get_run_event_store(request)
|
||||
|
||||
# Get run_events config for journal
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
run_events_config = getattr(get_app_config(), "run_events", None)
|
||||
|
||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||
|
||||
@ -287,6 +293,8 @@ async def start_run(
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
event_store=event_store,
|
||||
run_events_config=run_events_config,
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
from deerflow.persistence.models.run import RunRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.persistence.models.thread_meta import ThreadMetaRow
|
||||
|
||||
__all__ = ["RunEventRow", "RunRow", "ThreadMetaRow"]
|
||||
49
backend/packages/harness/deerflow/persistence/models/run.py
Normal file
49
backend/packages/harness/deerflow/persistence/models/run.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""ORM model for run metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, Index, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class RunRow(Base):
    """ORM row holding one run's metadata, status, and token accounting.

    Token fields are accumulated in memory during the run and written once
    on completion (see ``RunRepository.update_run_completion``).
    """

    __tablename__ = "runs"

    run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    assistant_id: Mapped[str | None] = mapped_column(String(128))
    # Owning user; None means the run is not access-restricted.
    owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
    status: Mapped[str] = mapped_column(String(20), default="pending")
    # "pending" | "running" | "success" | "error" | "timeout" | "interrupted"

    model_name: Mapped[str | None] = mapped_column(String(128))
    multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
    # Free-form JSON payloads; exposed to callers as "metadata" / "kwargs"
    # after remapping in RunRepository._row_to_dict.
    metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
    kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
    error: Mapped[str | None] = mapped_column(Text)

    # Convenience fields (for listing pages without querying RunEventStore)
    message_count: Mapped[int] = mapped_column(default=0)
    first_human_message: Mapped[str | None] = mapped_column(Text)
    last_ai_message: Mapped[str | None] = mapped_column(Text)

    # Token usage (accumulated in-memory by RunJournal, written on run completion)
    total_input_tokens: Mapped[int] = mapped_column(default=0)
    total_output_tokens: Mapped[int] = mapped_column(default=0)
    total_tokens: Mapped[int] = mapped_column(default=0)
    llm_call_count: Mapped[int] = mapped_column(default=0)
    lead_agent_tokens: Mapped[int] = mapped_column(default=0)
    subagent_tokens: Mapped[int] = mapped_column(default=0)
    middleware_tokens: Mapped[int] = mapped_column(default=0)

    # Follow-up association
    follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))

    # Timezone-aware UTC timestamps; updated_at refreshes on every UPDATE.
    created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
    updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))

    # Composite index for the common "runs of thread X with status Y" query.
    __table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
||||
@ -0,0 +1,30 @@
|
||||
"""ORM model for run events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, Index, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class RunEventRow(Base):
    """ORM row for one run event (message, trace, or lifecycle entry).

    ``seq`` orders events within a thread; it is assigned by the event
    store at insert time (see DbRunEventStore), not by the database.
    """

    __tablename__ = "run_events"

    # Surrogate key; dropped from serialized dicts (seq is the public cursor).
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
    run_id: Mapped[str] = mapped_column(String(64), nullable=False)
    event_type: Mapped[str] = mapped_column(String(32), nullable=False)
    category: Mapped[str] = mapped_column(String(16), nullable=False)
    # "message" | "trace" | "lifecycle"
    content: Mapped[str] = mapped_column(Text, default="")
    # Named event_metadata because "metadata" is reserved on declarative models;
    # remapped to "metadata" when serialized.
    event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
    # Per-thread monotonic sequence number used for cursor pagination.
    seq: Mapped[int] = mapped_column(nullable=False)
    created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))

    __table_args__ = (
        # Thread timeline filtered by category (message listing).
        Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
        # Per-run event stream lookup.
        Index("ix_events_run", "thread_id", "run_id", "seq"),
    )
|
||||
@ -0,0 +1,23 @@
|
||||
"""ORM model for thread metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class ThreadMetaRow(Base):
    """ORM row for thread-level metadata and ownership.

    Absence of a row for a thread_id is treated as an untracked (open)
    thread by ThreadMetaRepository.check_access.
    """

    __tablename__ = "threads_meta"

    thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
    # None means the thread is shared (no ownership restriction).
    owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
    display_name: Mapped[str | None] = mapped_column(String(256))
    status: Mapped[str] = mapped_column(String(20), default="idle")
    # Free-form JSON; exposed to callers as "metadata" after remapping.
    metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
    # Timezone-aware UTC timestamps; updated_at refreshes on every UPDATE.
    created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
    updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||
@ -0,0 +1,4 @@
|
||||
from deerflow.persistence.repositories.run_repo import RunRepository
|
||||
from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository
|
||||
|
||||
__all__ = ["RunRepository", "ThreadMetaRepository"]
|
||||
@ -0,0 +1,174 @@
|
||||
"""SQLAlchemy-backed RunStore implementation.
|
||||
|
||||
Each method acquires and releases its own short-lived session.
|
||||
Run status updates happen from background workers that may live
|
||||
minutes -- we don't hold connections across long execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunRepository(RunStore):
    """SQLAlchemy-backed RunStore.

    Each method opens and closes its own short-lived session via the
    injected factory, so no connection is held across long-running
    background work.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    @staticmethod
    def _safe_json(obj: Any) -> Any:
        """Ensure obj is JSON-serializable. Falls back to model_dump() or str().

        Recurses into containers and into the dumps produced by pydantic-style
        objects, so nested non-serializable values are also sanitized.
        """
        if obj is None:
            return None
        if isinstance(obj, (str, int, float, bool)):
            return obj
        if isinstance(obj, dict):
            return {k: RunRepository._safe_json(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [RunRepository._safe_json(v) for v in obj]
        # Pydantic v2 style objects.
        if hasattr(obj, "model_dump"):
            try:
                # Recurse: the dump itself may contain non-JSON values.
                return RunRepository._safe_json(obj.model_dump())
            except Exception:
                pass
        # Pydantic v1 / dict-like objects.
        if hasattr(obj, "dict"):
            try:
                return RunRepository._safe_json(obj.dict())
            except Exception:
                pass
        try:
            json.dumps(obj)
            return obj
        except (TypeError, ValueError):
            # Last resort: a lossy but always-serializable representation.
            return str(obj)

    @staticmethod
    def _row_to_dict(row: RunRow) -> dict[str, Any]:
        """Serialize a RunRow to the dict shape the RunStore interface uses."""
        d = row.to_dict()
        # Remap JSON columns to match RunStore interface
        d["metadata"] = d.pop("metadata_json", {})
        d["kwargs"] = d.pop("kwargs_json", {})
        # Convert datetime to ISO string for consistency with MemoryRunStore
        for key in ("created_at", "updated_at"):
            val = d.get(key)
            if isinstance(val, datetime):
                d[key] = val.isoformat()
        return d

    async def put(
        self,
        run_id,
        *,
        thread_id,
        assistant_id=None,
        owner_id=None,
        status="pending",
        multitask_strategy="reject",
        metadata=None,
        kwargs=None,
        error=None,
        created_at=None,
    ):
        """Insert a new run row. ``created_at`` may be an ISO string."""
        now = datetime.now(UTC)
        row = RunRow(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            owner_id=owner_id,
            status=status,
            multitask_strategy=multitask_strategy,
            # Sanitize caller payloads so JSON column encoding cannot fail.
            metadata_json=self._safe_json(metadata) or {},
            kwargs_json=self._safe_json(kwargs) or {},
            error=error,
            created_at=datetime.fromisoformat(created_at) if created_at else now,
            updated_at=now,
        )
        async with self._sf() as session:
            session.add(row)
            await session.commit()

    async def get(self, run_id):
        """Return the run as a dict, or None when unknown."""
        async with self._sf() as session:
            row = await session.get(RunRow, run_id)
            return self._row_to_dict(row) if row else None

    async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
        """List a thread's runs, newest first, optionally filtered by owner."""
        stmt = select(RunRow).where(RunRow.thread_id == thread_id)
        if owner_id is not None:
            stmt = stmt.where(RunRow.owner_id == owner_id)
        stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def update_status(self, run_id, status, *, error=None):
        """Set a run's status (and error text, when provided)."""
        values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
        if error is not None:
            values["error"] = error
        async with self._sf() as session:
            await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
            await session.commit()

    async def delete(self, run_id):
        """Delete a run row; no-op when the run does not exist."""
        async with self._sf() as session:
            row = await session.get(RunRow, run_id)
            if row is not None:
                await session.delete(row)
                await session.commit()

    async def list_pending(self, *, before=None):
        """List runs still in "pending" state created at or before ``before``.

        ``before`` may be a datetime or an ISO-8601 string; defaults to now.
        Results are ordered oldest first.
        """
        # Bugfix: compare the datetime column against a datetime, not an ISO
        # string -- string comparison only happened to work on SQLite.
        if before is None:
            cutoff = datetime.now(UTC)
        elif isinstance(before, str):
            cutoff = datetime.fromisoformat(before)
        else:
            cutoff = before
        stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= cutoff).order_by(RunRow.created_at.asc())
        async with self._sf() as session:
            result = await session.execute(stmt)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        last_ai_message: str | None = None,
        first_human_message: str | None = None,
        error: str | None = None,
    ) -> None:
        """Update status + token usage + convenience fields on run completion."""
        values: dict[str, Any] = {
            "status": status,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "llm_call_count": llm_call_count,
            "lead_agent_tokens": lead_agent_tokens,
            "subagent_tokens": subagent_tokens,
            "middleware_tokens": middleware_tokens,
            "message_count": message_count,
            "updated_at": datetime.now(UTC),
        }
        # Convenience snippets are capped to keep the row small.
        if last_ai_message is not None:
            values["last_ai_message"] = last_ai_message[:2000]
        if first_human_message is not None:
            values["first_human_message"] = first_human_message[:2000]
        if error is not None:
            values["error"] = error
        async with self._sf() as session:
            await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
            await session.commit()
|
||||
@ -0,0 +1,91 @@
|
||||
"""SQLAlchemy-backed thread metadata repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.thread_meta import ThreadMetaRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThreadMetaRepository:
    """SQLAlchemy-backed repository for thread metadata and ownership checks.

    Every method uses its own short-lived session from the injected factory.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._sf = session_factory

    @staticmethod
    def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
        """Serialize a row: rename metadata_json and ISO-format timestamps."""
        data = row.to_dict()
        data["metadata"] = data.pop("metadata_json", {})
        for ts_field in ("created_at", "updated_at"):
            ts = data.get(ts_field)
            if isinstance(ts, datetime):
                data[ts_field] = ts.isoformat()
        return data

    async def create(
        self,
        thread_id: str,
        *,
        assistant_id: str | None = None,
        owner_id: str | None = None,
        display_name: str | None = None,
        metadata: dict | None = None,
    ) -> dict:
        """Insert a thread-meta row and return its serialized form."""
        timestamp = datetime.now(UTC)
        new_row = ThreadMetaRow(
            thread_id=thread_id,
            assistant_id=assistant_id,
            owner_id=owner_id,
            display_name=display_name,
            metadata_json=metadata or {},
            created_at=timestamp,
            updated_at=timestamp,
        )
        async with self._sf() as session:
            session.add(new_row)
            await session.commit()
            # Reload DB-populated state before serializing.
            await session.refresh(new_row)
            return self._row_to_dict(new_row)

    async def get(self, thread_id: str) -> dict | None:
        """Fetch a thread's metadata, or None when the thread is unknown."""
        async with self._sf() as session:
            found = await session.get(ThreadMetaRow, thread_id)
            if found is None:
                return None
            return self._row_to_dict(found)

    async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
        """List an owner's threads, most recently updated first."""
        query = (
            select(ThreadMetaRow)
            .where(ThreadMetaRow.owner_id == owner_id)
            .order_by(ThreadMetaRow.updated_at.desc())
            .limit(limit)
            .offset(offset)
        )
        async with self._sf() as session:
            result = await session.execute(query)
            return [self._row_to_dict(r) for r in result.scalars()]

    async def check_access(self, thread_id: str, owner_id: str) -> bool:
        """Check if owner_id has access to thread_id.

        Returns True if: row doesn't exist (untracked thread), owner_id
        is None on the row (shared thread), or owner_id matches.
        """
        async with self._sf() as session:
            row = await session.get(ThreadMetaRow, thread_id)
            return row is None or row.owner_id is None or row.owner_id == owner_id

    async def update_status(self, thread_id: str, status: str) -> None:
        """Set the thread's status and bump updated_at."""
        stmt = (
            update(ThreadMetaRow)
            .where(ThreadMetaRow.thread_id == thread_id)
            .values(status=status, updated_at=datetime.now(UTC))
        )
        async with self._sf() as session:
            await session.execute(stmt)
            await session.commit()

    async def delete(self, thread_id: str) -> None:
        """Delete a thread-meta row; no-op when absent."""
        async with self._sf() as session:
            target = await session.get(ThreadMetaRow, thread_id)
            if target is None:
                return
            await session.delete(target)
            await session.commit()
|
||||
@ -1,4 +1,26 @@
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore"]
|
||||
|
||||
def make_run_event_store(config=None) -> RunEventStore:
    """Instantiate the event-store backend named by ``config.backend``.

    With no config (or ``backend == "memory"``) the in-memory store is used.
    A ``"db"`` backend silently degrades to memory when no database engine
    has been initialized. Unknown backends raise ValueError.
    """
    backend = "memory" if config is None else config.backend
    if backend == "memory":
        return MemoryRunEventStore()
    if backend == "db":
        from deerflow.persistence.engine import get_session_factory

        session_factory = get_session_factory()
        if session_factory is None:
            # database.backend=memory but run_events.backend=db -> fallback
            return MemoryRunEventStore()
        from deerflow.runtime.events.store.db import DbRunEventStore

        return DbRunEventStore(session_factory, max_trace_content=config.max_trace_content)
    if backend == "jsonl":
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        return JsonlRunEventStore()
    raise ValueError(f"Unknown run_events backend: {config.backend!r}")
|
||||
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]
|
||||
|
||||
148
backend/packages/harness/deerflow/runtime/events/store/db.py
Normal file
148
backend/packages/harness/deerflow/runtime/events/store/db.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""SQLAlchemy-backed RunEventStore implementation.
|
||||
|
||||
Persists events to the ``run_events`` table. Trace content is truncated
|
||||
at ``max_trace_content`` bytes to avoid bloating the database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
|
||||
class DbRunEventStore(RunEventStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
|
||||
self._sf = session_factory
|
||||
self._max_trace_content = max_trace_content
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: RunEventRow) -> dict:
|
||||
d = row.to_dict()
|
||||
d["metadata"] = d.pop("event_metadata", {})
|
||||
val = d.get("created_at")
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
d.pop("id", None)
|
||||
return d
|
||||
|
||||
def _truncate_trace(self, category: str, content: str, metadata: dict | None) -> tuple[str, dict]:
|
||||
if category == "trace" and len(content) > self._max_trace_content:
|
||||
content = content[: self._max_trace_content]
|
||||
metadata = {**(metadata or {}), "content_truncated": True}
|
||||
return content, metadata or {}
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
async with self._sf() as session:
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
async with self._sf() as session:
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread)
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
seq += 1
|
||||
content = e.get("content", "")
|
||||
category = e.get("category", "trace")
|
||||
metadata = e.get("metadata")
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
rows.append(row)
|
||||
await session.commit()
|
||||
for row in rows:
|
||||
await session.refresh(row)
|
||||
return [self._row_to_dict(r) for r in rows]
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
# Forward pagination: first `limit` records after cursor
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
# before_seq or default (latest): take last `limit` records, return ascending
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id):
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def count_messages(self, thread_id):
    """Count message-category events across all runs of *thread_id*."""
    stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
    async with self._sf() as session:
        # scalar() can yield None on an empty result; normalize to 0.
        return await session.scalar(stmt) or 0
||||
|
||||
async def delete_by_thread(self, thread_id):
    """Delete all events of a thread and return the number removed.

    NOTE(review): the count and the delete are two separate statements in
    one session, so a concurrent writer could make the returned count
    approximate -- confirm this is acceptable for callers.
    """
    async with self._sf() as session:
        count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
        count = await session.scalar(count_stmt) or 0
        # Skip the DELETE (and the commit) entirely when there is nothing to remove.
        if count > 0:
            await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
            await session.commit()
        return count
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
    """Delete all events of one run within a thread; return the number removed.

    Mirrors ``delete_by_thread`` but scoped to a single ``run_id``.
    """
    async with self._sf() as session:
        count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
        count = await session.scalar(count_stmt) or 0
        # Only issue the DELETE/commit when rows actually exist.
        if count > 0:
            await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
            await session.commit()
        return count
||||
164
backend/packages/harness/deerflow/runtime/events/store/jsonl.py
Normal file
164
backend/packages/harness/deerflow/runtime/events/store/jsonl.py
Normal file
@ -0,0 +1,164 @@
|
||||
"""JSONL file-backed RunEventStore implementation.
|
||||
|
||||
Each run's events are stored in a single file:
|
||||
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``
|
||||
|
||||
All categories (message, trace, lifecycle) are in the same file.
|
||||
This backend is suitable for lightweight single-node deployments.
|
||||
|
||||
Known trade-off: ``list_messages()`` must scan all run files for a
|
||||
thread since messages from multiple runs need unified seq ordering.
|
||||
``list_events()`` reads only one file -- the fast path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonlRunEventStore(RunEventStore):
    """RunEventStore backed by append-only JSONL files, one file per run.

    Layout: ``{base_dir}/threads/{thread_id}/runs/{run_id}.jsonl``.
    A per-thread monotonic ``seq`` counter is kept in memory and lazily
    recovered from disk, so cross-run ordering survives process restarts.

    Fix vs. previous revision: ``list_messages`` applied ``before_seq`` and
    ``after_seq`` as mutually exclusive filters (if/elif), while the DB
    backend applies both as WHERE filters and anchors the page on the
    ``after_seq`` side when it is set. Both cursors are now applied
    together, matching the DB backend; single-cursor behavior is unchanged.
    """

    def __init__(self, base_dir: str | Path | None = None):
        self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
        self._seq_counters: dict[str, int] = {}  # thread_id -> current max seq

    def _thread_dir(self, thread_id: str) -> Path:
        """Directory holding all run files for *thread_id*."""
        return self._base_dir / "threads" / thread_id / "runs"

    def _run_file(self, thread_id: str, run_id: str) -> Path:
        """Path of the JSONL file for one run."""
        return self._thread_dir(thread_id) / f"{run_id}.jsonl"

    def _next_seq(self, thread_id: str) -> int:
        """Advance and return the thread-wide monotonic sequence number."""
        self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
        return self._seq_counters[thread_id]

    def _ensure_seq_loaded(self, thread_id: str) -> None:
        """Load max seq from existing files if not yet cached."""
        if thread_id in self._seq_counters:
            return
        max_seq = 0
        thread_dir = self._thread_dir(thread_id)
        if thread_dir.exists():
            for f in thread_dir.glob("*.jsonl"):
                for line in f.read_text(encoding="utf-8").strip().splitlines():
                    try:
                        record = json.loads(line)
                        max_seq = max(max_seq, record.get("seq", 0))
                    except json.JSONDecodeError:
                        # Tolerate a torn/corrupt line (e.g. from a crash mid-write).
                        continue
        self._seq_counters[thread_id] = max_seq

    def _write_record(self, record: dict) -> None:
        """Append one JSON line to the record's run file, creating dirs as needed."""
        path = self._run_file(record["thread_id"], record["run_id"])
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "a", encoding="utf-8") as f:
            f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")

    def _read_thread_events(self, thread_id: str) -> list[dict]:
        """Read all events for a thread, sorted by seq."""
        events = []
        thread_dir = self._thread_dir(thread_id)
        if not thread_dir.exists():
            return events
        for f in sorted(thread_dir.glob("*.jsonl")):
            for line in f.read_text(encoding="utf-8").strip().splitlines():
                if not line:
                    continue
                try:
                    events.append(json.loads(line))
                except json.JSONDecodeError:
                    continue
        events.sort(key=lambda e: e.get("seq", 0))
        return events

    def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
        """Read events for a specific run file, sorted by seq."""
        path = self._run_file(thread_id, run_id)
        if not path.exists():
            return []
        events = []
        for line in path.read_text(encoding="utf-8").strip().splitlines():
            if not line:
                continue
            try:
                events.append(json.loads(line))
            except json.JSONDecodeError:
                continue
        events.sort(key=lambda e: e.get("seq", 0))
        return events

    async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
        """Assign the next seq, append the record to the run file, return it."""
        self._ensure_seq_loaded(thread_id)
        seq = self._next_seq(thread_id)
        record = {
            "thread_id": thread_id,
            "run_id": run_id,
            "event_type": event_type,
            "category": category,
            "content": content,
            "metadata": metadata or {},
            "seq": seq,
            "created_at": created_at or datetime.now(UTC).isoformat(),
        }
        self._write_record(record)
        return record

    async def put_batch(self, events):
        """Write a batch of events sequentially; returns the stored records."""
        if not events:
            return []
        results = []
        for ev in events:
            record = await self.put(**ev)
            results.append(record)
        return results

    async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
        """List message-category events with cursor pagination.

        Both cursors, when given, are applied as filters (matching the DB
        backend). The page anchors to the ``after_seq`` side (forward
        pagination) when it is set; otherwise the latest *limit* messages
        are returned, ascending by seq.
        """
        all_events = self._read_thread_events(thread_id)
        messages = [e for e in all_events if e.get("category") == "message"]
        if before_seq is not None:
            messages = [e for e in messages if e["seq"] < before_seq]
        if after_seq is not None:
            messages = [e for e in messages if e["seq"] > after_seq]
            # Forward pagination: first `limit` records after the cursor.
            return messages[:limit]
        # before_seq or default (latest): last `limit` records, ascending.
        return messages[-limit:]

    async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
        """Return up to *limit* events of one run, optionally filtered by type."""
        events = self._read_run_events(thread_id, run_id)
        if event_types is not None:
            events = [e for e in events if e.get("event_type") in event_types]
        return events[:limit]

    async def list_messages_by_run(self, thread_id, run_id):
        """Return the message-category events of one run, ascending by seq."""
        events = self._read_run_events(thread_id, run_id)
        return [e for e in events if e.get("category") == "message"]

    async def count_messages(self, thread_id):
        """Count message-category events across all runs of the thread."""
        all_events = self._read_thread_events(thread_id)
        return sum(1 for e in all_events if e.get("category") == "message")

    async def delete_by_thread(self, thread_id):
        """Delete every run file of the thread; return parseable-event count.

        NOTE(review): the count reflects only lines that parsed as JSON, so
        corrupt lines are deleted but not counted -- confirm acceptable.
        """
        all_events = self._read_thread_events(thread_id)
        count = len(all_events)
        thread_dir = self._thread_dir(thread_id)
        if thread_dir.exists():
            for f in thread_dir.glob("*.jsonl"):
                f.unlink()
        # Drop the cached counter so a later put() re-scans (now-empty) disk.
        self._seq_counters.pop(thread_id, None)
        return count

    async def delete_by_run(self, thread_id, run_id):
        """Delete one run's file; return the number of parseable events removed."""
        events = self._read_run_events(thread_id, run_id)
        count = len(events)
        path = self._run_file(thread_id, run_id)
        if path.exists():
            path.unlink()
        return count
||||
333
backend/packages/harness/deerflow/runtime/journal.py
Normal file
333
backend/packages/harness/deerflow/runtime/journal.py
Normal file
@ -0,0 +1,333 @@
|
||||
"""Run event capture via LangChain callbacks.
|
||||
|
||||
RunJournal sits between LangChain's callback mechanism and the pluggable
|
||||
RunEventStore. It standardizes callback data into RunEvent records and
|
||||
handles token usage accumulation.
|
||||
|
||||
Key design decisions:
|
||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||
- All LangChain objects serialized via serialize_lc_object (same as worker.py SSE)
|
||||
- Token usage accumulated in memory, written to RunRow on run completion
|
||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunJournal(BaseCallbackHandler):
    """LangChain callback handler that captures events to RunEventStore.

    Events are buffered in memory and flushed to the store in batches of
    ``flush_threshold``; lifecycle-end callbacks force a flush. Callback
    methods are synchronous (BaseCallbackHandler contract), so flushes are
    scheduled as tasks on the running event loop.
    """

    def __init__(
        self,
        run_id: str,
        thread_id: str,
        event_store: RunEventStore,
        *,
        track_token_usage: bool = True,
        on_complete: Callable[..., Any] | None = None,
        flush_threshold: int = 20,
    ):
        super().__init__()
        self.run_id = run_id
        self.thread_id = thread_id
        self._store = event_store
        self._track_tokens = track_token_usage
        self._on_complete = on_complete
        self._flush_threshold = flush_threshold

        # Write buffer (pending events, flushed in batches)
        self._buffer: list[dict] = []

        # Token accumulators (updated in on_llm_end when tracking is enabled)
        self._total_input_tokens = 0
        self._total_output_tokens = 0
        self._total_tokens = 0
        self._llm_call_count = 0
        self._lead_agent_tokens = 0
        self._subagent_tokens = 0
        self._middleware_tokens = 0

        # Convenience fields surfaced via get_completion_data()/on_complete
        self._last_ai_msg: str | None = None
        self._first_human_msg: str | None = None
        self._msg_count = 0

        # Latency tracking
        self._llm_start_times: dict[str, float] = {}  # langchain run_id -> start time

    # -- Lifecycle callbacks --

    def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``run_start`` lifecycle event for the top-level chain only."""
        # Only record for the top-level chain (parent_run_id is None)
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_start",
            category="lifecycle",
            metadata={"input_preview": str(inputs)[:500]},
        )

    def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record ``run_end``, force a flush, and invoke on_complete with totals."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
        self._flush_sync()
        # on_complete receives the same fields as get_completion_data().
        if self._on_complete:
            self._on_complete(
                total_input_tokens=self._total_input_tokens,
                total_output_tokens=self._total_output_tokens,
                total_tokens=self._total_tokens,
                llm_call_count=self._llm_call_count,
                lead_agent_tokens=self._lead_agent_tokens,
                subagent_tokens=self._subagent_tokens,
                middleware_tokens=self._middleware_tokens,
                message_count=self._msg_count,
                last_ai_message=self._last_ai_msg,
                first_human_message=self._first_human_msg,
            )

    def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``run_error`` lifecycle event and flush immediately."""
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
            event_type="run_error",
            category="lifecycle",
            content=str(error),
            metadata={"error_type": type(error).__name__},
        )
        self._flush_sync()

    # -- LLM callbacks --

    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, **kwargs: Any) -> None:
        """Record an ``llm_start`` trace event and start the latency clock."""
        self._llm_start_times[str(run_id)] = time.monotonic()
        self._put(
            event_type="llm_start",
            category="trace",
            metadata={"model_name": serialized.get("name", "")},
        )

    def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record an ``llm_end`` trace event, emit ai_message for lead-agent
        replies with text content, and accumulate token usage."""
        from deerflow.runtime.serialization import serialize_lc_object

        try:
            message = response.generations[0][0].message
        except (IndexError, AttributeError):
            logger.debug("on_llm_end: could not extract message from response")
            return

        serialized_msg = serialize_lc_object(message)
        caller = self._identify_caller(kwargs)

        # Latency
        start = self._llm_start_times.pop(str(run_id), None)
        latency_ms = int((time.monotonic() - start) * 1000) if start else None

        # Token usage from message
        usage = getattr(message, "usage_metadata", None)
        usage_dict = dict(usage) if usage else {}

        # trace event: llm_end (every LLM call)
        self._put(
            event_type="llm_end",
            category="trace",
            content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")),
            metadata={
                "message": serialized_msg,
                "caller": caller,
                "usage": usage_dict,
                "latency_ms": latency_ms,
            },
        )

        # message event: ai_message (only lead_agent final replies with content)
        if caller == "lead_agent":
            content = getattr(message, "content", "")
            if isinstance(content, str) and content:
                tool_calls = getattr(message, "tool_calls", None) or []
                tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
                resp_meta = getattr(message, "response_metadata", None) or {}
                model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
                self._put(
                    event_type="ai_message",
                    category="message",
                    content=content,
                    metadata={
                        "model_name": model_name,
                        "tool_calls": tool_calls_summary,
                    },
                )
                # Convenience fields: truncated preview + message counter.
                self._last_ai_msg = content[:2000]
                self._msg_count += 1

        # Token accumulation (only when usage reported a positive total)
        input_tk = usage_dict.get("input_tokens", 0) or 0
        output_tk = usage_dict.get("output_tokens", 0) or 0
        total_tk = usage_dict.get("total_tokens", 0) or 0
        if self._track_tokens and total_tk > 0:
            self._total_input_tokens += input_tk
            self._total_output_tokens += output_tk
            self._total_tokens += total_tk
            self._llm_call_count += 1
            # Bucket by caller tag; anything unrecognized counts as lead agent.
            if caller.startswith("subagent:"):
                self._subagent_tokens += total_tk
            elif caller.startswith("middleware:"):
                self._middleware_tokens += total_tk
            else:
                self._lead_agent_tokens += total_tk

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record an ``llm_error`` trace event and drop the latency entry."""
        self._llm_start_times.pop(str(run_id), None)
        self._put(event_type="llm_error", category="trace", content=str(error))

    # -- Tool callbacks --

    def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``tool_start`` trace event with a truncated args preview."""
        self._put(
            event_type="tool_start",
            category="trace",
            metadata={
                "tool_name": serialized.get("name", ""),
                "tool_call_id": kwargs.get("tool_call_id"),
                "args": str(input_str)[:2000],
            },
        )

    def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``tool_end`` trace event carrying the tool output."""
        self._put(
            event_type="tool_end",
            category="trace",
            content=str(output),
            metadata={
                "tool_name": kwargs.get("name", ""),
                "tool_call_id": kwargs.get("tool_call_id"),
                "status": "success",
            },
        )

    def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        """Record a ``tool_error`` trace event."""
        self._put(
            event_type="tool_error",
            category="trace",
            content=str(error),
            metadata={
                "tool_name": kwargs.get("name", ""),
                "tool_call_id": kwargs.get("tool_call_id"),
            },
        )

    # -- Custom event callback --

    def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
        """Record custom events; ``summarization`` gets both a trace and a
        message-category ``summary`` event, everything else is trace-only."""
        from deerflow.runtime.serialization import serialize_lc_object

        if name == "summarization":
            data_dict = data if isinstance(data, dict) else {}
            self._put(
                event_type="summarization",
                category="trace",
                content=data_dict.get("summary", ""),
                metadata={
                    "replaced_message_ids": data_dict.get("replaced_message_ids", []),
                    "replaced_count": data_dict.get("replaced_count", 0),
                },
            )
            self._put(
                event_type="summary",
                category="message",
                content=data_dict.get("summary", ""),
                metadata={"replaced_count": data_dict.get("replaced_count", 0)},
            )
        else:
            event_data = serialize_lc_object(data) if not isinstance(data, dict) else data
            self._put(
                event_type=name,
                category="trace",
                metadata=event_data if isinstance(event_data, dict) else {"data": event_data},
            )

    # -- Internal methods --

    def _put(self, *, event_type: str, category: str, content: str = "", metadata: dict | None = None) -> None:
        """Buffer one event dict; flush when the buffer reaches the threshold."""
        self._buffer.append({
            "thread_id": self.thread_id,
            "run_id": self.run_id,
            "event_type": event_type,
            "category": category,
            "content": content,
            "metadata": metadata or {},
            "created_at": datetime.now(UTC).isoformat(),
        })
        if len(self._buffer) >= self._flush_threshold:
            self._flush_sync()

    def _flush_sync(self) -> None:
        """Flush buffer to RunEventStore.

        BaseCallbackHandler methods are synchronous. We schedule the async
        put_batch via the current event loop.
        """
        if not self._buffer:
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(self._flush_async(batch))
        except RuntimeError:
            # No running loop (e.g. called from a plain thread): events are lost.
            logger.warning("RunJournal: no event loop, dropping %d events", len(batch))

    async def _flush_async(self, batch: list[dict]) -> None:
        """Write one batch to the store; failures are logged, never raised."""
        try:
            await self._store.put_batch(batch)
        except Exception:
            logger.warning("RunJournal: failed to flush %d events", len(batch), exc_info=True)

    def _identify_caller(self, kwargs: dict) -> str:
        """Return the caller tag (lead_agent / subagent:* / middleware:*) or "unknown"."""
        for tag in kwargs.get("tags") or []:
            if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
                return tag
        return "unknown"

    # -- Public methods (called by worker) --

    def set_first_human_message(self, content: str) -> None:
        """Record the first human message for convenience fields."""
        self._first_human_msg = content[:2000] if content else None

    async def flush(self) -> None:
        """Force flush. Used in cancel/error paths."""
        if self._buffer:
            batch = self._buffer.copy()
            self._buffer.clear()
            await self._store.put_batch(batch)

    def get_completion_data(self) -> dict:
        """Return accumulated token and message data for run completion."""
        return {
            "total_input_tokens": self._total_input_tokens,
            "total_output_tokens": self._total_output_tokens,
            "total_tokens": self._total_tokens,
            "llm_call_count": self._llm_call_count,
            "lead_agent_tokens": self._lead_agent_tokens,
            "subagent_tokens": self._subagent_tokens,
            "middleware_tokens": self._middleware_tokens,
            "message_count": self._msg_count,
            "last_ai_message": self._last_ai_msg,
            "first_human_message": self._first_human_msg,
        }
|
||||
@ -1,4 +1,4 @@
|
||||
"""In-memory run registry."""
|
||||
"""In-memory run registry with optional persistent RunStore backing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -7,9 +7,13 @@ import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -38,11 +42,17 @@ class RunRecord:
|
||||
|
||||
|
||||
class RunManager:
|
||||
"""In-memory run registry. All mutations are protected by an asyncio lock."""
|
||||
"""In-memory run registry with optional persistent RunStore backing.
|
||||
|
||||
def __init__(self) -> None:
|
||||
All mutations are protected by an asyncio lock. When a ``store`` is
|
||||
provided, serializable metadata is also persisted to the store so
|
||||
that run history survives process restarts.
|
||||
"""
|
||||
|
||||
def __init__(self, store: RunStore | None = None) -> None:
|
||||
self._runs: dict[str, RunRecord] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
self._store = store
|
||||
|
||||
async def create(
|
||||
self,
|
||||
@ -71,6 +81,20 @@ class RunManager:
|
||||
)
|
||||
async with self._lock:
|
||||
self._runs[run_id] = record
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.put(
|
||||
run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending.value,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
@ -96,6 +120,11 @@ class RunManager:
|
||||
record.updated_at = _now_iso()
|
||||
if error is not None:
|
||||
record.error = error
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.update_status(run_id, status.value, error=error)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
|
||||
logger.info("Run %s -> %s", run_id, status.value)
|
||||
|
||||
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
|
||||
@ -185,6 +214,21 @@ class RunManager:
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.put(
|
||||
run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status=RunStatus.pending.value,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata=metadata or {},
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
|
||||
|
||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||
return record
|
||||
|
||||
|
||||
@ -45,6 +45,8 @@ async def run_agent(
|
||||
stream_subgraphs: bool = False,
|
||||
interrupt_before: list[str] | Literal["*"] | None = None,
|
||||
interrupt_after: list[str] | Literal["*"] | None = None,
|
||||
event_store: Any | None = None,
|
||||
run_events_config: Any | None = None,
|
||||
) -> None:
|
||||
"""Execute an agent in the background, publishing events to *bridge*."""
|
||||
|
||||
@ -52,6 +54,30 @@ async def run_agent(
|
||||
thread_id = record.thread_id
|
||||
requested_modes: set[str] = set(stream_modes or ["values"])
|
||||
|
||||
# Initialize RunJournal for event capture
|
||||
journal = None
|
||||
if event_store is not None:
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
journal = RunJournal(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
# Write human_message event
|
||||
user_input = _extract_user_input(graph_input)
|
||||
if user_input:
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=user_input,
|
||||
)
|
||||
journal.set_first_human_message(user_input)
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
logger.info(
|
||||
@ -92,6 +118,10 @@ async def run_agent(
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
# Inject RunJournal as a callback
|
||||
if journal is not None:
|
||||
config.setdefault("callbacks", []).append(journal)
|
||||
|
||||
runnable_config = RunnableConfig(**config)
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
@ -206,6 +236,13 @@ async def run_agent(
|
||||
)
|
||||
|
||||
finally:
|
||||
# Flush any buffered journal events
|
||||
if journal is not None:
|
||||
try:
|
||||
await journal.flush()
|
||||
except Exception:
|
||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||
|
||||
await bridge.publish_end(run_id)
|
||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||
|
||||
@ -227,6 +264,23 @@ def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
return mode
|
||||
|
||||
|
||||
def _extract_user_input(graph_input: dict) -> str:
|
||||
"""Extract user input text from graph_input for event recording."""
|
||||
messages = graph_input.get("messages")
|
||||
if not messages:
|
||||
return ""
|
||||
# Take the last message (usually the user's input)
|
||||
last = messages[-1] if isinstance(messages, list) else messages
|
||||
if isinstance(last, str):
|
||||
return last
|
||||
if hasattr(last, "content"):
|
||||
content = last.content
|
||||
return content if isinstance(content, str) else str(content)
|
||||
if isinstance(last, dict):
|
||||
return str(last.get("content", ""))
|
||||
return ""
|
||||
|
||||
|
||||
def _unpack_stream_item(
|
||||
item: Any,
|
||||
lg_modes: list[str],
|
||||
|
||||
154
backend/tests/test_phase2b_integration.py
Normal file
154
backend/tests/test_phase2b_integration.py
Normal file
@ -0,0 +1,154 @@
|
||||
"""Phase 2-B integration tests.
|
||||
|
||||
End-to-end test: simulate a run's complete lifecycle, verify data
|
||||
is correctly written to both RunStore and RunEventStore.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
|
||||
def _make_llm_response(content="Hello", usage=None):
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.tool_calls = []
|
||||
msg.response_metadata = {"model_name": "test-model"}
|
||||
msg.usage_metadata = usage
|
||||
|
||||
gen = MagicMock()
|
||||
gen.message = msg
|
||||
|
||||
response = MagicMock()
|
||||
response.generations = [[gen]]
|
||||
return response
|
||||
|
||||
|
||||
class TestRunLifecycle:
    """End-to-end tests: run lifecycle across RunStore + RunEventStore,
    RunJournal callback capture, and RunManager store backing."""

    @pytest.mark.anyio
    async def test_full_run_lifecycle(self):
        """Simulate a complete run lifecycle with RunStore + RunEventStore."""
        run_store = MemoryRunStore()
        event_store = MemoryRunEventStore()

        # 1. Create run
        await run_store.put("r1", thread_id="t1", status="pending")

        # 2. Write human_message
        await event_store.put(
            thread_id="t1",
            run_id="r1",
            event_type="human_message",
            category="message",
            content="What is AI?",
        )

        # 3. Simulate RunJournal callback sequence
        on_complete_data = {}

        def on_complete(**data):
            on_complete_data.update(data)

        # flush_threshold=100 keeps everything buffered until explicit flush.
        journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100)
        journal.set_first_human_message("What is AI?")

        # chain_start (top-level)
        journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None)

        # llm_start + llm_end
        llm_run_id = uuid4()
        journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"])
        usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}
        journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"])

        # chain_end (triggers on_complete + flush_sync which creates a task)
        journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        await journal.flush()
        # Let event loop process any pending flush tasks from _flush_sync
        await asyncio.sleep(0.05)

        # 4. Verify messages
        messages = await event_store.list_messages("t1")
        assert len(messages) == 2  # human + ai
        assert messages[0]["event_type"] == "human_message"
        assert messages[1]["event_type"] == "ai_message"
        assert messages[1]["content"] == "AI is artificial intelligence."

        # 5. Verify events
        events = await event_store.list_events("t1", "r1")
        event_types = {e["event_type"] for e in events}
        assert "run_start" in event_types
        assert "llm_start" in event_types
        assert "llm_end" in event_types
        assert "run_end" in event_types

        # 6. Verify on_complete data
        assert on_complete_data["total_tokens"] == 150
        assert on_complete_data["llm_call_count"] == 1
        assert on_complete_data["lead_agent_tokens"] == 150
        assert on_complete_data["message_count"] == 1
        assert on_complete_data["last_ai_message"] == "AI is artificial intelligence."
        assert on_complete_data["first_human_message"] == "What is AI?"

    @pytest.mark.anyio
    async def test_run_with_tool_calls(self):
        """Simulate a run that uses tools."""
        event_store = MemoryRunEventStore()
        journal = RunJournal("r1", "t1", event_store, flush_threshold=100)

        # tool_start + tool_end
        journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4())
        journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search")
        await journal.flush()

        events = await event_store.list_events("t1", "r1")
        assert len(events) == 2
        assert events[0]["event_type"] == "tool_start"
        assert events[1]["event_type"] == "tool_end"

    @pytest.mark.anyio
    async def test_multi_run_thread(self):
        """Multiple runs on the same thread maintain unified seq ordering."""
        event_store = MemoryRunEventStore()

        # Run 1
        await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
        await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")

        # Run 2
        await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
        await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")

        # seq is thread-wide, not per-run: 1..4 across both runs.
        messages = await event_store.list_messages("t1")
        assert len(messages) == 4
        assert [m["seq"] for m in messages] == [1, 2, 3, 4]
        assert messages[0]["run_id"] == "r1"
        assert messages[2]["run_id"] == "r2"

    @pytest.mark.anyio
    async def test_runmanager_with_store_backing(self):
        """RunManager persists to RunStore when one is provided."""
        from deerflow.runtime.runs.manager import RunManager

        run_store = MemoryRunStore()
        mgr = RunManager(store=run_store)

        record = await mgr.create("t1", assistant_id="lead_agent")
        # Verify persisted to store
        row = await run_store.get(record.run_id)
        assert row is not None
        assert row["thread_id"] == "t1"
        assert row["status"] == "pending"

        # Status update
        from deerflow.runtime.runs.schemas import RunStatus

        await mgr.set_status(record.run_id, RunStatus.running)
        row = await run_store.get(record.run_id)
        assert row["status"] == "running"
||||
@ -1,14 +1,7 @@
|
||||
"""Tests for RunEventStore ABC + MemoryRunEventStore.
|
||||
"""Tests for RunEventStore contract across all backends.
|
||||
|
||||
Covers:
|
||||
- Basic write and query (put, seq assignment, cross-thread independence)
|
||||
- list_messages (category filtering, pagination, cross-run ordering)
|
||||
- list_events (run filtering, event_types filtering)
|
||||
- list_messages_by_run
|
||||
- count_messages
|
||||
- put_batch
|
||||
- delete_by_thread, delete_by_run
|
||||
- Edge cases (empty thread/run)
|
||||
Uses a helper to create the store for each backend type.
|
||||
Memory tests run directly; DB and JSONL tests create stores inside each test.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@ -35,7 +28,6 @@ class TestPutAndSeq:
|
||||
assert record["event_type"] == "human_message"
|
||||
assert record["category"] == "message"
|
||||
assert record["content"] == "hello"
|
||||
assert record["metadata"] == {}
|
||||
assert "created_at" in record
|
||||
|
||||
@pytest.mark.anyio
|
||||
@ -91,7 +83,6 @@ class TestListMessages:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_before_seq_pagination(self, store):
|
||||
# Put 10 messages with seq 1..10
|
||||
for i in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
messages = await store.list_messages("t1", before_seq=6, limit=3)
|
||||
@ -236,7 +227,6 @@ class TestDelete:
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
count = await store.delete_by_run("t1", "r2")
|
||||
assert count == 2
|
||||
# r1 events should still be there
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
@ -270,3 +260,145 @@ class TestEdgeCases:
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_thread_count_messages(self, store):
|
||||
assert await store.count_messages("empty") == 0
|
||||
|
||||
|
||||
# -- DB-specific tests --
|
||||
|
||||
|
||||
class TestDbRunEventStore:
    """Tests for DbRunEventStore with temp SQLite.

    Each test opens its own engine via ``_open_store`` and closes it in a
    ``finally`` block so a failing assertion cannot leak the global engine
    into subsequent tests.
    """

    async def _open_store(self, tmp_path, **store_kwargs):
        """Initialize a SQLite engine under *tmp_path* and return a store.

        Extra ``store_kwargs`` (e.g. ``max_trace_content``) are forwarded to
        the DbRunEventStore constructor.
        """
        from deerflow.persistence.engine import get_session_factory, init_engine
        from deerflow.runtime.events.store.db import DbRunEventStore

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        return DbRunEventStore(get_session_factory(), **store_kwargs)

    @staticmethod
    async def _close():
        """Tear down the engine opened by ``_open_store``."""
        from deerflow.persistence.engine import close_engine

        await close_engine()

    @pytest.mark.anyio
    async def test_basic_crud(self, tmp_path):
        s = await self._open_store(tmp_path)
        try:
            r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
            assert r["seq"] == 1
            r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello")
            assert r2["seq"] == 2

            messages = await s.list_messages("t1")
            assert len(messages) == 2

            assert await s.count_messages("t1") == 2
        finally:
            await self._close()

    @pytest.mark.anyio
    async def test_trace_content_truncation(self, tmp_path):
        s = await self._open_store(tmp_path, max_trace_content=100)
        try:
            long = "x" * 200
            r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long)
            assert len(r["content"]) == 100
            assert r["metadata"].get("content_truncated") is True

            # message content NOT truncated
            m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long)
            assert len(m["content"]) == 200
        finally:
            await self._close()

    @pytest.mark.anyio
    async def test_pagination(self, tmp_path):
        s = await self._open_store(tmp_path)
        try:
            for i in range(10):
                await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))

            # before_seq: window of messages strictly older than seq 6
            msgs = await s.list_messages("t1", before_seq=6, limit=3)
            assert [m["seq"] for m in msgs] == [3, 4, 5]

            # after_seq: strictly newer than seq 7
            msgs = await s.list_messages("t1", after_seq=7, limit=3)
            assert [m["seq"] for m in msgs] == [8, 9, 10]

            # default (latest)
            msgs = await s.list_messages("t1", limit=3)
            assert [m["seq"] for m in msgs] == [8, 9, 10]
        finally:
            await self._close()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        s = await self._open_store(tmp_path)
        try:
            await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
            await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
            assert await s.delete_by_run("t1", "r2") == 1
            assert await s.count_messages("t1") == 1

            assert await s.delete_by_thread("t1") == 1
            assert await s.count_messages("t1") == 0
        finally:
            await self._close()
|
||||
|
||||
|
||||
# -- JSONL-specific tests --
|
||||
|
||||
|
||||
class TestJsonlRunEventStore:
    """JSONL backend: one append-only file per run under <base>/threads/<tid>/runs/."""

    @pytest.mark.anyio
    async def test_basic_crud(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        record = await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="human_message",
            category="message",
            content="hi",
        )
        assert record["seq"] == 1
        assert len(await store.list_messages("t1")) == 1

    @pytest.mark.anyio
    async def test_file_at_correct_path(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        base = tmp_path / "jsonl"
        store = JsonlRunEventStore(base_dir=base)
        await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
        # Events for run r1 land in a dedicated per-run file.
        assert (base / "threads" / "t1" / "runs" / "r1.jsonl").exists()

    @pytest.mark.anyio
    async def test_cross_run_messages(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        store = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
        for run in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=run, event_type="human_message", category="message")
        msgs = await store.list_messages("t1")
        assert len(msgs) == 2
        # seq counter is shared across runs within the thread.
        assert [m["seq"] for m in msgs] == [1, 2]

    @pytest.mark.anyio
    async def test_delete_by_run(self, tmp_path):
        from deerflow.runtime.events.store.jsonl import JsonlRunEventStore

        base = tmp_path / "jsonl"
        store = JsonlRunEventStore(base_dir=base)
        for run in ("r1", "r2"):
            await store.put(thread_id="t1", run_id=run, event_type="human_message", category="message")
        assert await store.delete_by_run("t1", "r2") == 1
        # The backing file for r2 is removed along with its events.
        assert not (base / "threads" / "t1" / "runs" / "r2.jsonl").exists()
        assert await store.count_messages("t1") == 1
|
||||
|
||||
230
backend/tests/test_run_journal.py
Normal file
230
backend/tests/test_run_journal.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""Tests for RunJournal callback handler.
|
||||
|
||||
Uses MemoryRunEventStore as the backend for direct event inspection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
|
||||
@pytest.fixture
def journal_setup():
    """Provide a journal wired to an in-memory store plus a completion sink.

    Returns a ``(journal, store, completion_data)`` triple; ``completion_data``
    is a dict that accumulates whatever the journal passes to ``on_complete``.
    """
    store = MemoryRunEventStore()
    completion_data: dict = {}

    def _capture(**data):
        completion_data.update(data)

    journal = RunJournal("r1", "t1", store, on_complete=_capture, flush_threshold=100)
    return journal, store, completion_data
|
||||
|
||||
|
||||
def _make_llm_response(content="Hello", usage=None):
|
||||
"""Create a mock LLM response with a message."""
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.tool_calls = []
|
||||
msg.response_metadata = {"model_name": "test-model"}
|
||||
msg.usage_metadata = usage
|
||||
|
||||
gen = MagicMock()
|
||||
gen.message = msg
|
||||
|
||||
response = MagicMock()
|
||||
response.generations = [[gen]]
|
||||
return response
|
||||
|
||||
|
||||
class TestLlmCallbacks:
    """LLM start/end callbacks: trace events, chat messages, token bookkeeping."""

    @pytest.mark.anyio
    async def test_on_llm_end_produces_trace_event(self, journal_setup):
        journal, store, _ = journal_setup
        llm_run = uuid4()
        journal.on_llm_start({}, [], run_id=llm_run, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("Hi"), run_id=llm_run, tags=["lead_agent"])
        await journal.flush()
        recorded = await store.list_events("t1", "r1")
        end_events = [e for e in recorded if e["event_type"] == "llm_end"]
        assert len(end_events) == 1
        assert end_events[0]["category"] == "trace"

    @pytest.mark.anyio
    async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
        journal, store, _ = journal_setup
        llm_run = uuid4()
        journal.on_llm_start({}, [], run_id=llm_run, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("Answer"), run_id=llm_run, tags=["lead_agent"])
        await journal.flush()
        msgs = await store.list_messages("t1")
        assert len(msgs) == 1
        assert msgs[0]["event_type"] == "ai_message"
        assert msgs[0]["content"] == "Answer"

    @pytest.mark.anyio
    async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
        journal, store, _ = journal_setup
        llm_run = uuid4()
        subagent_tags = ["subagent:research"]
        journal.on_llm_start({}, [], run_id=llm_run, tags=subagent_tags)
        journal.on_llm_end(_make_llm_response("Sub answer"), run_id=llm_run, tags=subagent_tags)
        await journal.flush()
        # Only the lead agent's output surfaces as a chat message.
        assert len(await store.list_messages("t1")) == 0

    @pytest.mark.anyio
    async def test_token_accumulation(self, journal_setup):
        journal, _, _ = journal_setup
        usages = [
            {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
            {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
        ]
        for text, usage in zip(("A", "B"), usages):
            journal.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
            journal.on_llm_end(_make_llm_response(text, usage=usage), run_id=uuid4(), tags=["lead_agent"])
        assert journal._total_input_tokens == 30
        assert journal._total_output_tokens == 15
        assert journal._total_tokens == 45
        assert journal._llm_call_count == 2

    @pytest.mark.anyio
    async def test_caller_token_classification(self, journal_setup):
        journal, _, _ = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        calls = (
            ("A", "lead_agent"),
            ("B", "subagent:research"),
            ("C", "middleware:summarization"),
        )
        for text, tag in calls:
            journal.on_llm_start({}, [], run_id=uuid4(), tags=[tag])
            journal.on_llm_end(_make_llm_response(text, usage=usage), run_id=uuid4(), tags=[tag])
        # Each caller type gets its own token bucket.
        assert journal._lead_agent_tokens == 15
        assert journal._subagent_tokens == 15
        assert journal._middleware_tokens == 15

    @pytest.mark.anyio
    async def test_usage_metadata_none_no_crash(self, journal_setup):
        journal, _, _ = journal_setup
        journal.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"])
        # Should not raise even when the provider reported no usage.
        await journal.flush()

    @pytest.mark.anyio
    async def test_latency_tracking(self, journal_setup):
        journal, store, _ = journal_setup
        llm_run = uuid4()
        journal.on_llm_start({}, [], run_id=llm_run, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("Fast"), run_id=llm_run, tags=["lead_agent"])
        await journal.flush()
        recorded = await store.list_events("t1", "r1")
        end_event = next(e for e in recorded if e["event_type"] == "llm_end")
        assert "latency_ms" in end_event["metadata"]
        assert end_event["metadata"]["latency_ms"] is not None
|
||||
|
||||
|
||||
class TestLifecycleCallbacks:
    """Chain start/end callbacks drive the run's lifecycle reporting."""

    @pytest.mark.anyio
    async def test_on_chain_end_triggers_on_complete(self, journal_setup):
        # A top-level chain (no parent) ending fires the on_complete callback.
        journal, _, completion = journal_setup
        journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
        journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        assert "total_tokens" in completion
        assert "message_count" in completion

    @pytest.mark.anyio
    async def test_nested_chain_ignored(self, journal_setup):
        # Chains nested under a parent run produce no lifecycle events.
        journal, store, _ = journal_setup
        parent = uuid4()
        journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent)
        journal.on_chain_end({}, run_id=uuid4(), parent_run_id=parent)
        await journal.flush()
        recorded = await store.list_events("t1", "r1")
        assert len([e for e in recorded if e["category"] == "lifecycle"]) == 0
|
||||
|
||||
|
||||
class TestToolCallbacks:
    """Tool start/end callbacks are captured as trace events."""

    @pytest.mark.anyio
    async def test_tool_start_end_produce_trace(self, journal_setup):
        journal, store, _ = journal_setup
        journal.on_tool_start({"name": "web_search"}, "query", run_id=uuid4())
        journal.on_tool_end("results", run_id=uuid4(), name="web_search")
        await journal.flush()
        recorded = await store.list_events("t1", "r1")
        seen_types = {e["event_type"] for e in recorded}
        assert "tool_start" in seen_types
        assert "tool_end" in seen_types
|
||||
|
||||
|
||||
class TestCustomEvents:
    """Custom dispatched events (e.g. context summarization)."""

    @pytest.mark.anyio
    async def test_summarization_event(self, journal_setup):
        journal, store, _ = journal_setup
        payload = {
            "summary": "Context was summarized.",
            "replaced_count": 5,
            "replaced_message_ids": ["a", "b"],
        }
        journal.on_custom_event("summarization", payload, run_id=uuid4())
        await journal.flush()

        # The event is recorded once as a trace...
        recorded = await store.list_events("t1", "r1")
        assert len([e for e in recorded if e["event_type"] == "summarization"]) == 1

        # ...and also surfaced as a "summary" chat message.
        msgs = await store.list_messages("t1")
        assert len(msgs) == 1
        assert msgs[0]["event_type"] == "summary"
|
||||
|
||||
|
||||
class TestBufferFlush:
    """Buffered events are flushed to the store once the threshold is hit."""

    @pytest.mark.anyio
    async def test_flush_threshold(self, journal_setup):
        j, store, _ = journal_setup
        j._flush_threshold = 3
        j.on_tool_start({"name": "a"}, "x", run_id=uuid4())
        j.on_tool_start({"name": "b"}, "x", run_id=uuid4())
        # Buffer has 2 events, not yet flushed
        assert len(j._buffer) == 2
        j.on_tool_start({"name": "c"}, "x", run_id=uuid4())
        # Threshold of 3 triggers an async flush task. Poll with a bounded
        # deadline instead of a fixed sleep: fast when the flush is quick,
        # tolerant when the event loop is busy (less flaky than sleep(0.1)).
        events = []
        for _ in range(50):
            events = await store.list_events("t1", "r1")
            if len(events) >= 3:
                break
            await asyncio.sleep(0.01)
        assert len(events) >= 3
|
||||
|
||||
|
||||
class TestIdentifyCaller:
    """_identify_caller maps callback kwargs onto a caller identity string."""

    def test_lead_agent_tag(self, journal_setup):
        journal, _, _ = journal_setup
        assert journal._identify_caller({"tags": ["lead_agent"]}) == "lead_agent"

    def test_subagent_tag(self, journal_setup):
        journal, _, _ = journal_setup
        assert journal._identify_caller({"tags": ["subagent:research"]}) == "subagent:research"

    def test_middleware_tag(self, journal_setup):
        journal, _, _ = journal_setup
        assert journal._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization"

    def test_no_tags_returns_unknown(self, journal_setup):
        journal, _, _ = journal_setup
        # Both an empty tag list and a missing "tags" key fall back to "unknown".
        for kwargs in ({"tags": []}, {}):
            assert journal._identify_caller(kwargs) == "unknown"
|
||||
|
||||
|
||||
class TestPublicMethods:
    """Direct exercise of the journal's public helper methods."""

    @pytest.mark.anyio
    async def test_set_first_human_message(self, journal_setup):
        journal, _, _ = journal_setup
        journal.set_first_human_message("Hello world")
        assert journal._first_human_msg == "Hello world"

    @pytest.mark.anyio
    async def test_get_completion_data(self, journal_setup):
        journal, _, _ = journal_setup
        journal._total_tokens = 100
        journal._msg_count = 5
        summary = journal.get_completion_data()
        assert summary["total_tokens"] == 100
        assert summary["message_count"] == 5
||||
155
backend/tests/test_run_repository.py
Normal file
155
backend/tests/test_run_repository.py
Normal file
@ -0,0 +1,155 @@
|
||||
"""Tests for RunRepository (SQLAlchemy-backed RunStore).
|
||||
|
||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.repositories.run_repo import RunRepository
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
    """Spin up a temp SQLite engine under *tmp_path* and return a repository."""
    from deerflow.persistence.engine import get_session_factory, init_engine

    db_url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path))
    return RunRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _cleanup():
    """Dispose of the global engine opened by ``_make_repo``."""
    from deerflow.persistence.engine import close_engine

    await close_engine()
|
||||
|
||||
|
||||
class TestRunRepository:
    """CRUD tests for RunRepository against a temp SQLite engine.

    Every test opens an engine via ``_make_repo`` and closes it in a
    ``finally`` block so a failed assertion cannot leak the global engine
    into later tests.
    """

    @pytest.mark.anyio
    async def test_put_and_get(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", status="pending")
            row = await repo.get("r1")
            assert row is not None
            assert row["run_id"] == "r1"
            assert row["thread_id"] == "t1"
            assert row["status"] == "pending"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_get_missing_returns_none(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            assert await repo.get("nope") is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_update_status(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.update_status("r1", "running")
            row = await repo.get("r1")
            assert row["status"] == "running"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_update_status_with_error(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.update_status("r1", "error", error="boom")
            row = await repo.get("r1")
            assert row["status"] == "error"
            assert row["error"] == "boom"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.put("r2", thread_id="t1")
            await repo.put("r3", thread_id="t2")
            rows = await repo.list_by_thread("t1")
            assert len(rows) == 2
            assert all(r["thread_id"] == "t1" for r in rows)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_by_thread_owner_filter(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", owner_id="alice")
            await repo.put("r2", thread_id="t1", owner_id="bob")
            rows = await repo.list_by_thread("t1", owner_id="alice")
            assert len(rows) == 1
            assert rows[0]["owner_id"] == "alice"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1")
            await repo.delete("r1")
            assert await repo.get("r1") is None
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_delete_nonexistent_is_noop(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.delete("nope")  # should not raise
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_list_pending(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", status="pending")
            await repo.put("r2", thread_id="t1", status="running")
            await repo.put("r3", thread_id="t2", status="pending")
            pending = await repo.list_pending()
            assert len(pending) == 2
            assert all(r["status"] == "pending" for r in pending)
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_update_run_completion(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", status="running")
            await repo.update_run_completion(
                "r1",
                status="success",
                total_input_tokens=100,
                total_output_tokens=50,
                total_tokens=150,
                llm_call_count=2,
                lead_agent_tokens=120,
                subagent_tokens=20,
                middleware_tokens=10,
                message_count=3,
                last_ai_message="The answer is 42",
                first_human_message="What is the meaning?",
            )
            row = await repo.get("r1")
            assert row["status"] == "success"
            assert row["total_tokens"] == 150
            assert row["llm_call_count"] == 2
            assert row["lead_agent_tokens"] == 120
            assert row["message_count"] == 3
            assert row["last_ai_message"] == "The answer is 42"
            assert row["first_human_message"] == "What is the meaning?"
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_metadata_preserved(self, tmp_path):
        repo = await _make_repo(tmp_path)
        try:
            await repo.put("r1", thread_id="t1", metadata={"key": "value"})
            row = await repo.get("r1")
            assert row["metadata"] == {"key": "value"}
        finally:
            await _cleanup()

    @pytest.mark.anyio
    async def test_kwargs_with_non_serializable(self, tmp_path):
        """kwargs containing non-JSON-serializable objects should be safely handled."""
        repo = await _make_repo(tmp_path)
        try:
            class Dummy:
                pass

            await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()})
            row = await repo.get("r1")
            assert "obj" in row["kwargs"]
        finally:
            await _cleanup()
|
||||
Loading…
x
Reference in New Issue
Block a user