refactor(journal): fix flush and token tracking, consolidate tests

RunJournal fixes:
- _flush_sync: retain events in the buffer when no event loop is running
  instead of dropping them; the worker's finally block flushes them via the
  async flush().
- on_llm_end: emit ai_message only for lead_agent final replies with no
  pending tool_calls; mark message IDs for dedup with record_llm_usage.
- _identify_caller: default to "lead_agent" when no caller tag is present
  (subagents and middleware tag themselves explicitly).
- Replace the on_complete callback with a get_completion_data() accessor.
- worker.py: persist completion data (tokens, message count) to RunStore
  in finally block.

Model factory:
- Auto-inject stream_usage=True for BaseChatOpenAI subclasses with
  custom api_base, so usage_metadata is populated in streaming responses.

Test consolidation:
- Delete test_phase2b_integration.py (redundant with existing tests).
- Move DB-backed lifecycle test into test_run_journal.py.
- Add tests for stream_usage injection in test_model_factory.py.
- Clean up executor/task_tool dead journal references.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
rayhpeng 2026-04-03 17:26:11 +08:00
parent e5b01d7e74
commit b92ddafd4b
7 changed files with 360 additions and 451 deletions

View File

@@ -77,6 +77,15 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium"
# Ensure stream_usage is enabled so that token usage metadata is available
# in streaming responses. LangChain's BaseChatOpenAI only defaults
# stream_usage=True when no custom base_url/api_base is set, so models
# hitting third-party endpoints (e.g. doubao, deepseek) silently lose
# usage data. We default it to True unless explicitly configured.
if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
if "stream_usage" in getattr(model_class, "model_fields", {}):
model_settings_from_config["stream_usage"] = True
model_instance = model_class(**kwargs, **model_settings_from_config)
if is_tracing_enabled():
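
To see what the injected flag changes, here is a minimal sketch (endpoint, model name, and key are illustrative; assumes langchain-openai's ChatOpenAI):

    from langchain_openai import ChatOpenAI

    # With a custom base_url, BaseChatOpenAI does not enable stream_usage on
    # its own; the factory now injects stream_usage=True, shown explicitly here.
    model = ChatOpenAI(
        model="deepseek-chat",
        base_url="https://api.example.com/v1",  # hypothetical endpoint
        api_key="sk-...",
        stream_usage=True,
    )

    usage = None
    for chunk in model.stream("Hello"):
        # With stream_usage enabled, the final streamed chunk carries
        # usage_metadata; without it, usage stays None and the journal
        # records zero tokens for the call.
        usage = chunk.usage_metadata or usage
    print(usage)  # e.g. {"input_tokens": 5, "output_tokens": 12, "total_tokens": 17}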

View File

@@ -16,7 +16,6 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
@@ -39,7 +38,6 @@ class RunJournal(BaseCallbackHandler):
event_store: RunEventStore,
*,
track_token_usage: bool = True,
on_complete: Callable[..., Any] | None = None,
flush_threshold: int = 20,
):
super().__init__()
@@ -47,7 +45,6 @@
self.thread_id = thread_id
self._store = event_store
self._track_tokens = track_token_usage
self._on_complete = on_complete
self._flush_threshold = flush_threshold
# Write buffer
@@ -73,7 +70,6 @@
# -- Lifecycle callbacks --
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
# Only record for the top-level chain (parent_run_id is None)
if kwargs.get("parent_run_id") is not None:
return
self._put(
@@ -87,19 +83,6 @@
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync()
if self._on_complete:
self._on_complete(
total_input_tokens=self._total_input_tokens,
total_output_tokens=self._total_output_tokens,
total_tokens=self._total_tokens,
llm_call_count=self._llm_call_count,
lead_agent_tokens=self._lead_agent_tokens,
subagent_tokens=self._subagent_tokens,
middleware_tokens=self._middleware_tokens,
message_count=self._msg_count,
last_ai_message=self._last_ai_msg,
first_human_message=self._first_human_msg,
)
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
@@ -131,7 +114,6 @@
logger.debug("on_llm_end: could not extract message from response")
return
serialized_msg = serialize_lc_object(message)
caller = self._identify_caller(kwargs)
# Latency
@@ -142,54 +124,52 @@
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# trace event: llm_end (every LLM call)
# Trace event: llm_end (every LLM call)
content = getattr(message, "content", "")
self._put(
event_type="llm_end",
category="trace",
content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")),
content=content if isinstance(content, str) else str(content),
metadata={
"message": serialized_msg,
"message": serialize_lc_object(message),
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
},
)
# message event: ai_message (only lead_agent final replies with content)
if caller == "lead_agent":
content = getattr(message, "content", "")
if isinstance(content, str) and content:
tool_calls = getattr(message, "tool_calls", None) or []
tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
self._put(
event_type="ai_message",
category="message",
content=content,
metadata={
"model_name": model_name,
"tool_calls": tool_calls_summary,
},
)
self._last_ai_msg = content[:2000]
self._msg_count += 1
# Message event: ai_message (only lead_agent final replies — no pending tool_calls)
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent" and isinstance(content, str) and content and not tool_calls:
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
self._put(
event_type="ai_message",
category="message",
content=content,
metadata={"model_name": model_name},
)
self._last_ai_msg = content[:2000]
self._msg_count += 1
# Token accumulation
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if self._track_tokens and total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
@@ -277,20 +257,23 @@
self._flush_sync()
def _flush_sync(self) -> None:
"""Flush buffer to RunEventStore.
"""Best-effort flush of buffer to RunEventStore.
BaseCallbackHandler methods are synchronous. We schedule the async
put_batch via the current event loop.
BaseCallbackHandler methods are synchronous. If an event loop is
running we schedule an async ``put_batch``; otherwise the events
stay in the buffer and are flushed later by the async ``flush()``
call in the worker's ``finally`` block.
"""
if not self._buffer:
return
batch = self._buffer.copy()
self._buffer.clear()
try:
loop = asyncio.get_running_loop()
loop.create_task(self._flush_async(batch))
except RuntimeError:
logger.warning("RunJournal: no event loop, dropping %d events", len(batch))
# No event loop — keep events in buffer for later async flush.
return
batch = self._buffer.copy()
self._buffer.clear()
loop.create_task(self._flush_async(batch))
async def _flush_async(self, batch: list[dict]) -> None:
try:
@@ -302,7 +285,10 @@
for tag in kwargs.get("tags") or []:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag
return "unknown"
# Default to lead_agent: the main agent graph does not inject
# callback tags, while subagents and middleware explicitly tag
# themselves.
return "lead_agent"
# -- Public methods (called by worker) --
@@ -311,7 +297,7 @@
self._first_human_msg = content[:2000] if content else None
async def flush(self) -> None:
"""Force flush. Used in cancel/error paths."""
"""Force flush remaining buffer. Called in worker's finally block."""
if self._buffer:
batch = self._buffer.copy()
self._buffer.clear()
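
The lead_agent fallback above relies on callers tagging themselves: in LangChain, tags passed through the run config propagate into callback kwargs. A minimal sketch of both paths (model construction is illustrative):

    from langchain_openai import ChatOpenAI

    model = ChatOpenAI(model="gpt-4o-mini")  # any chat model works here

    # Subagents and middleware tag their calls; the tag surfaces in the
    # journal's on_llm_start/on_llm_end kwargs["tags"].
    model.invoke("Summarize the findings", config={"tags": ["subagent:research"]})

    # The lead agent's graph passes no tags, so _identify_caller sees an
    # empty list and now attributes tokens to "lead_agent", not "unknown".
    model.invoke("Write the final answer")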

View File

@@ -123,7 +123,8 @@ async def run_agent(
runtime = Runtime(context={"thread_id": thread_id}, store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a callback
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
if journal is not None:
config.setdefault("callbacks", []).append(journal)
@@ -241,13 +242,25 @@
)
finally:
# Flush any buffered journal events
# Flush any buffered journal events and persist completion data
if journal is not None:
try:
await journal.flush()
except Exception:
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
# Persist token usage + convenience fields to RunStore
if run_manager._store is not None:
try:
completion = journal.get_completion_data()
await run_manager._store.update_run_completion(
run_id,
status=record.status.value,
**completion,
)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60))
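
For reference, a sketch of the payload persisted here, with the keys implied by the removed on_complete call and the tests in this diff (values are illustrative, not the verbatim implementation):

    completion = journal.get_completion_data()
    # Expected shape, mirroring the fields formerly forwarded to on_complete:
    # {
    #     "total_input_tokens": 1200, "total_output_tokens": 300,
    #     "total_tokens": 1500, "llm_call_count": 4,
    #     "lead_agent_tokens": 900, "subagent_tokens": 450,
    #     "middleware_tokens": 150, "message_count": 2,
    #     "last_ai_message": "Final answer...", "first_human_message": "What is AI?",
    # }
    # update_run_completion(run_id, status=..., **completion) maps these
    # straight onto RunStore columns.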

View File

@@ -593,6 +593,84 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
assert "max_tokens" not in FakeChatModel.captured_kwargs
# ---------------------------------------------------------------------------
# stream_usage injection
# ---------------------------------------------------------------------------
class _FakeWithStreamUsage(FakeChatModel):
"""Fake model that declares stream_usage in model_fields (like BaseChatOpenAI)."""
stream_usage: bool | None = None
def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
"""Factory should set stream_usage=True for models with stream_usage field."""
cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
_patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)
captured: dict = {}
class CapturingModel(_FakeWithStreamUsage):
def __init__(self, **kwargs):
captured.update(kwargs)
BaseChatModel.__init__(self, **kwargs)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="deepseek")
assert captured.get("stream_usage") is True
def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
"""Factory should NOT inject stream_usage for models without the field."""
cfg = _make_app_config([_make_model("claude", use="langchain_anthropic:ChatAnthropic")])
_patch_factory(monkeypatch, cfg)
captured: dict = {}
class CapturingModel(FakeChatModel):
def __init__(self, **kwargs):
captured.update(kwargs)
BaseChatModel.__init__(self, **kwargs)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="claude")
assert "stream_usage" not in captured
def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
"""If config dumps stream_usage=False, factory should respect it."""
cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
_patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)
captured: dict = {}
class CapturingModel(_FakeWithStreamUsage):
def __init__(self, **kwargs):
captured.update(kwargs)
BaseChatModel.__init__(self, **kwargs)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
# Simulate config with stream_usage explicitly set by patching get_model_config
original_get_model_config = cfg.get_model_config
def patched_get_model_config(name):
mc = original_get_model_config(name)
mc.stream_usage = False # type: ignore[attr-defined]
return mc
monkeypatch.setattr(cfg, "get_model_config", patched_get_model_config)
factory_module.create_chat_model(name="deepseek")
assert captured.get("stream_usage") is False
def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
model = ModelConfig(
name="gpt-5-responses",

View File

@@ -15,7 +15,6 @@ import pytest
from deerflow.config.database_config import DatabaseConfig
from deerflow.runtime.runs.store.memory import MemoryRunStore
# -- DatabaseConfig --

View File

@@ -1,279 +0,0 @@
"""Phase 2-B integration tests.
End-to-end test: simulate a run's complete lifecycle, verify data
is correctly written to both RunStore and RunEventStore.
"""
import asyncio
from uuid import uuid4
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
from deerflow.runtime.journal import RunJournal
from deerflow.runtime.runs.store.memory import MemoryRunStore
class _FakeMessage:
def __init__(self, content, usage):
self.content = content
self.tool_calls = []
self.response_metadata = {"model_name": "test-model"}
self.usage_metadata = usage
self.id = "test-msg-id"
def model_dump(self):
return {"type": "ai", "content": self.content, "id": self.id, "tool_calls": [], "usage_metadata": self.usage_metadata, "response_metadata": self.response_metadata}
class _FakeGeneration:
def __init__(self, message):
self.message = message
class _FakeLLMResult:
def __init__(self, content, usage):
self.generations = [[_FakeGeneration(_FakeMessage(content, usage))]]
def _make_llm_response(content="Hello", usage=None):
return _FakeLLMResult(content, usage)
class TestRunLifecycle:
@pytest.mark.anyio
async def test_full_run_lifecycle(self):
"""Simulate a complete run lifecycle with RunStore + RunEventStore."""
run_store = MemoryRunStore()
event_store = MemoryRunEventStore()
# 1. Create run
await run_store.put("r1", thread_id="t1", status="pending")
# 2. Write human_message
await event_store.put(
thread_id="t1",
run_id="r1",
event_type="human_message",
category="message",
content="What is AI?",
)
# 3. Simulate RunJournal callback sequence
on_complete_data = {}
def on_complete(**data):
on_complete_data.update(data)
journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100)
journal.set_first_human_message("What is AI?")
# chain_start (top-level)
journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None)
# llm_start + llm_end
llm_run_id = uuid4()
journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"])
usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}
journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"])
# chain_end (triggers on_complete + flush_sync which creates a task)
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await journal.flush()
# Let event loop process any pending flush tasks from _flush_sync
await asyncio.sleep(0.05)
# 4. Verify messages
messages = await event_store.list_messages("t1")
assert len(messages) == 2 # human + ai
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
assert messages[1]["content"] == "AI is artificial intelligence."
# 5. Verify events
events = await event_store.list_events("t1", "r1")
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
# 6. Verify on_complete data
assert on_complete_data["total_tokens"] == 150
assert on_complete_data["llm_call_count"] == 1
assert on_complete_data["lead_agent_tokens"] == 150
assert on_complete_data["message_count"] == 1
assert on_complete_data["last_ai_message"] == "AI is artificial intelligence."
assert on_complete_data["first_human_message"] == "What is AI?"
@pytest.mark.anyio
async def test_run_with_tool_calls(self):
"""Simulate a run that uses tools."""
event_store = MemoryRunEventStore()
journal = RunJournal("r1", "t1", event_store, flush_threshold=100)
# tool_start + tool_end
journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4())
journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search")
await journal.flush()
events = await event_store.list_events("t1", "r1")
assert len(events) == 2
assert events[0]["event_type"] == "tool_start"
assert events[1]["event_type"] == "tool_end"
@pytest.mark.anyio
async def test_multi_run_thread(self):
"""Multiple runs on the same thread maintain unified seq ordering."""
event_store = MemoryRunEventStore()
# Run 1
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
# Run 2
await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")
messages = await event_store.list_messages("t1")
assert len(messages) == 4
assert [m["seq"] for m in messages] == [1, 2, 3, 4]
assert messages[0]["run_id"] == "r1"
assert messages[2]["run_id"] == "r2"
@pytest.mark.anyio
async def test_runmanager_with_store_backing(self):
"""RunManager persists to RunStore when one is provided."""
from deerflow.runtime.runs.manager import RunManager
run_store = MemoryRunStore()
mgr = RunManager(store=run_store)
record = await mgr.create("t1", assistant_id="lead_agent")
# Verify persisted to store
row = await run_store.get(record.run_id)
assert row is not None
assert row["thread_id"] == "t1"
assert row["status"] == "pending"
# Status update
from deerflow.runtime.runs.schemas import RunStatus
await mgr.set_status(record.run_id, RunStatus.running)
row = await run_store.get(record.run_id)
assert row["status"] == "running"
@pytest.mark.anyio
async def test_runmanager_create_or_reject_persists(self):
"""create_or_reject also persists to store."""
from deerflow.runtime.runs.manager import RunManager
run_store = MemoryRunStore()
mgr = RunManager(store=run_store)
record = await mgr.create_or_reject("t1", "lead_agent", metadata={"key": "val"})
row = await run_store.get(record.run_id)
assert row is not None
assert row["status"] == "pending"
assert row["metadata"] == {"key": "val"}
@pytest.mark.anyio
async def test_follow_up_metadata_in_messages(self):
"""human_message metadata carries follow_up_to_run_id."""
event_store = MemoryRunEventStore()
# Run 1
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
# Run 2 (follow-up)
await event_store.put(
thread_id="t1",
run_id="r2",
event_type="human_message",
category="message",
content="Tell me more",
metadata={"follow_up_to_run_id": "r1"},
)
messages = await event_store.list_messages("t1")
assert len(messages) == 3
assert messages[2]["metadata"]["follow_up_to_run_id"] == "r1"
@pytest.mark.anyio
async def test_summarization_in_history(self):
"""summary message appears correctly in message history."""
event_store = MemoryRunEventStore()
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
await event_store.put(thread_id="t1", run_id="r2", event_type="summary", category="message", content="Previous conversation summarized.", metadata={"replaced_count": 2})
await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")
messages = await event_store.list_messages("t1")
assert len(messages) == 5
assert messages[2]["event_type"] == "summary"
assert messages[2]["metadata"]["replaced_count"] == 2
@pytest.mark.anyio
async def test_db_backed_run_lifecycle(self, tmp_path):
"""Full lifecycle with SQLite-backed RunRepository + DbRunEventStore."""
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.persistence.repositories.run_repo import RunRepository
from deerflow.runtime.events.store.db import DbRunEventStore
from deerflow.runtime.runs.manager import RunManager
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
sf = get_session_factory()
run_store = RunRepository(sf)
event_store = DbRunEventStore(sf)
mgr = RunManager(store=run_store)
# Create run
record = await mgr.create("t1", "lead_agent")
run_id = record.run_id
# Write human_message
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB")
# Simulate journal
on_complete_data = {}
journal = RunJournal(run_id, "t1", event_store, on_complete=lambda **d: on_complete_data.update(d), flush_threshold=100)
journal.set_first_human_message("Hello DB")
journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
llm_rid = uuid4()
journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"])
journal.on_llm_end(_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=llm_rid, tags=["lead_agent"])
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await journal.flush()
await asyncio.sleep(0.05)
# Verify run persisted
row = await run_store.get(run_id)
assert row is not None
assert row["status"] == "pending" # RunManager set it, journal doesn't update status
# Update completion
await run_store.update_run_completion(run_id, status="success", **on_complete_data)
row = await run_store.get(run_id)
assert row["status"] == "success"
assert row["total_tokens"] == 15
# Verify messages from DB
messages = await event_store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
# Verify events from DB
events = await event_store.list_events("t1", run_id)
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
await close_engine()

View File

@@ -16,22 +16,28 @@ from deerflow.runtime.journal import RunJournal
@pytest.fixture
def journal_setup():
store = MemoryRunEventStore()
on_complete_data = {}
def on_complete(**data):
on_complete_data.update(data)
j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100)
return j, store, on_complete_data
j = RunJournal("r1", "t1", store, flush_threshold=100)
return j, store
def _make_llm_response(content="Hello", usage=None):
def _make_llm_response(content="Hello", usage=None, tool_calls=None):
"""Create a mock LLM response with a message."""
msg = MagicMock()
msg.content = content
msg.tool_calls = []
msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
# Provide a real model_dump so serialize_lc_object returns a plain dict
# (needed for DB-backed tests where json.dumps must succeed).
msg.model_dump.return_value = {
"type": "ai",
"content": content,
"id": msg.id,
"tool_calls": tool_calls or [],
"usage_metadata": usage,
"response_metadata": {"model_name": "test-model"},
}
gen = MagicMock()
gen.message = msg
@@ -44,7 +50,7 @@ def _make_llm_response(content="Hello", usage=None):
class TestLlmCallbacks:
@pytest.mark.anyio
async def test_on_llm_end_produces_trace_event(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
@@ -56,7 +62,7 @@
@pytest.mark.anyio
async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"])
@@ -66,9 +72,23 @@
assert messages[0]["event_type"] == "ai_message"
assert messages[0]["content"] == "Answer"
@pytest.mark.anyio
async def test_on_llm_end_with_tool_calls_no_ai_message(self, journal_setup):
"""LLM response with pending tool_calls should NOT produce ai_message."""
j, store = journal_setup
run_id = uuid4()
j.on_llm_end(
_make_llm_response("Let me search", tool_calls=[{"name": "search"}]),
run_id=run_id,
tags=["lead_agent"],
)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
@pytest.mark.anyio
async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"])
@@ -78,27 +98,34 @@
@pytest.mark.anyio
async def test_token_accumulation(self, journal_setup):
j, store, on_complete_data = journal_setup
j, store = journal_setup
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"])
assert j._total_input_tokens == 30
assert j._total_output_tokens == 15
assert j._total_tokens == 45
assert j._llm_call_count == 2
@pytest.mark.anyio
async def test_total_tokens_computed_from_input_output(self, journal_setup):
"""If total_tokens is 0, it should be computed from input + output."""
j, store = journal_setup
j.on_llm_end(
_make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}),
run_id=uuid4(),
tags=["lead_agent"],
)
assert j._total_tokens == 150
assert j._lead_agent_tokens == 150
@pytest.mark.anyio
async def test_caller_token_classification(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"])
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"])
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 15
@@ -106,15 +133,13 @@
@pytest.mark.anyio
async def test_usage_metadata_none_no_crash(self, journal_setup):
j, store, _ = journal_setup
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j, store = journal_setup
j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"])
# Should not raise
await j.flush()
@pytest.mark.anyio
async def test_latency_tracking(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
@@ -127,16 +152,20 @@
class TestLifecycleCallbacks:
@pytest.mark.anyio
async def test_on_chain_end_triggers_on_complete(self, journal_setup):
j, store, on_complete_data = journal_setup
async def test_chain_start_end_produce_lifecycle_events(self, journal_setup):
j, store = journal_setup
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert "total_tokens" in on_complete_data
assert "message_count" in on_complete_data
await asyncio.sleep(0.05)
await j.flush()
events = await store.list_events("t1", "r1")
types = [e["event_type"] for e in events if e["category"] == "lifecycle"]
assert "run_start" in types
assert "run_end" in types
@pytest.mark.anyio
async def test_nested_chain_ignored(self, journal_setup):
j, store, on_complete_data = journal_setup
j, store = journal_setup
parent_id = uuid4()
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
@@ -149,7 +178,7 @@
class TestToolCallbacks:
@pytest.mark.anyio
async def test_tool_start_end_produce_trace(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4())
j.on_tool_end("results", run_id=uuid4(), name="web_search")
await j.flush()
@@ -158,11 +187,19 @@
assert "tool_start" in types
assert "tool_end" in types
@pytest.mark.anyio
async def test_on_tool_error(self, journal_setup):
j, store = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "tool_error" for e in events)
class TestCustomEvents:
@pytest.mark.anyio
async def test_summarization_event(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
j.on_custom_event(
"summarization",
{"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]},
@@ -176,50 +213,76 @@
assert len(messages) == 1
assert messages[0]["event_type"] == "summary"
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
j, store = journal_setup
j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "task_running" for e in events)
class TestBufferFlush:
@pytest.mark.anyio
async def test_flush_threshold(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
j._flush_threshold = 3
j.on_tool_start({"name": "a"}, "x", run_id=uuid4())
j.on_tool_start({"name": "b"}, "x", run_id=uuid4())
# Buffer has 2 events, not yet flushed
assert len(j._buffer) == 2
j.on_tool_start({"name": "c"}, "x", run_id=uuid4())
# Buffer should have been flushed (threshold=3 triggers flush)
# Give the async task a chance to complete
await asyncio.sleep(0.1)
events = await store.list_events("t1", "r1")
assert len(events) >= 3
@pytest.mark.anyio
async def test_events_retained_when_no_loop(self, journal_setup):
"""Events buffered in a sync (no-loop) context should survive
until the async flush() in the finally block."""
j, store = journal_setup
j._flush_threshold = 1
original = asyncio.get_running_loop
def no_loop():
raise RuntimeError("no running event loop")
asyncio.get_running_loop = no_loop
try:
j._put(event_type="llm_end", category="trace", content="test")
finally:
asyncio.get_running_loop = original
assert len(j._buffer) == 1
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "llm_end" for e in events)
class TestIdentifyCaller:
def test_lead_agent_tag(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent"
def test_subagent_tag(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research"
def test_middleware_tag(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization"
def test_no_tags_returns_unknown(self, journal_setup):
j, _, _ = journal_setup
assert j._identify_caller({"tags": []}) == "unknown"
assert j._identify_caller({}) == "unknown"
def test_no_tags_returns_lead_agent(self, journal_setup):
j, _ = journal_setup
assert j._identify_caller({"tags": []}) == "lead_agent"
assert j._identify_caller({}) == "lead_agent"
class TestChainErrorCallback:
@pytest.mark.anyio
async def test_on_chain_error_writes_run_error(self, journal_setup):
j, store, _ = journal_setup
# parent_run_id must be None (top-level chain) for the event to be recorded
j, store = journal_setup
j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None)
# on_chain_error calls _flush_sync internally, give async task time to complete
await asyncio.sleep(0.05)
await j.flush()
events = await store.list_events("t1", "r1")
@@ -232,85 +295,125 @@
class TestTokenTrackingDisabled:
@pytest.mark.anyio
async def test_track_token_usage_false(self):
"""track_token_usage=False disables token accumulation."""
store = MemoryRunEventStore()
complete_data = {}
j = RunJournal("r1", "t1", store, track_token_usage=False, on_complete=lambda **d: complete_data.update(d), flush_threshold=100)
j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), run_id=uuid4(), tags=["lead_agent"])
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert complete_data["total_tokens"] == 0
assert complete_data["llm_call_count"] == 0
class TestMiddlewareNoMessage:
@pytest.mark.anyio
async def test_on_llm_end_middleware_no_ai_message(self, journal_setup):
j, store, _ = journal_setup
j.on_llm_end(_make_llm_response("Summary"), run_id=uuid4(), tags=["middleware:summarization"])
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
class TestUnknownCallerTokens:
@pytest.mark.anyio
async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
"""No caller tag: tokens attributed to lead_agent bucket."""
j, store, _ = journal_setup
j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=uuid4(), tags=[])
assert j._lead_agent_tokens == 15
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
j.on_llm_end(
_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}),
run_id=uuid4(),
tags=["lead_agent"],
)
data = j.get_completion_data()
assert data["total_tokens"] == 0
assert data["llm_call_count"] == 0
class TestConvenienceFields:
@pytest.mark.anyio
async def test_last_ai_message_tracks_latest(self, journal_setup):
j, store, complete_data = journal_setup
j, store = journal_setup
j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"])
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert complete_data["last_ai_message"] == "Second"
assert complete_data["message_count"] == 2
data = j.get_completion_data()
assert data["last_ai_message"] == "Second"
assert data["message_count"] == 2
@pytest.mark.anyio
async def test_first_human_message_via_set(self, journal_setup):
j, store, complete_data = journal_setup
j, _ = journal_setup
j.set_first_human_message("What is AI?")
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert complete_data["first_human_message"] == "What is AI?"
class TestToolError:
@pytest.mark.anyio
async def test_on_tool_error(self, journal_setup):
j, store, _ = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "tool_error" for e in events)
class TestOtherCustomEvent:
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
j, store, _ = journal_setup
j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "task_running" for e in events)
class TestPublicMethods:
@pytest.mark.anyio
async def test_set_first_human_message(self, journal_setup):
j, _, _ = journal_setup
j.set_first_human_message("Hello world")
assert j._first_human_msg == "Hello world"
data = j.get_completion_data()
assert data["first_human_message"] == "What is AI?"
@pytest.mark.anyio
async def test_get_completion_data(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
j._total_tokens = 100
j._msg_count = 5
data = j.get_completion_data()
assert data["total_tokens"] == 100
assert data["message_count"] == 5
class TestUnknownCallerTokens:
@pytest.mark.anyio
async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
j, store = journal_setup
j.on_llm_end(
_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
tags=[],
)
assert j._lead_agent_tokens == 15
# ---------------------------------------------------------------------------
# SQLite-backed end-to-end test
# ---------------------------------------------------------------------------
class TestDbBackedLifecycle:
@pytest.mark.anyio
async def test_full_lifecycle_with_sqlite(self, tmp_path):
"""Full lifecycle with SQLite-backed RunRepository + DbRunEventStore."""
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.persistence.repositories.run_repo import RunRepository
from deerflow.runtime.events.store.db import DbRunEventStore
from deerflow.runtime.runs.manager import RunManager
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
sf = get_session_factory()
run_store = RunRepository(sf)
event_store = DbRunEventStore(sf)
mgr = RunManager(store=run_store)
# Create run
record = await mgr.create("t1", "lead_agent")
run_id = record.run_id
# Write human_message
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB")
# Simulate journal
journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
journal.set_first_human_message("Hello DB")
journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
llm_rid = uuid4()
journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"])
journal.on_llm_end(
_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=llm_rid,
tags=["lead_agent"],
)
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await asyncio.sleep(0.05)
await journal.flush()
# Verify run persisted
row = await run_store.get(run_id)
assert row is not None
assert row["status"] == "pending"
# Update completion
completion = journal.get_completion_data()
await run_store.update_run_completion(run_id, status="success", **completion)
row = await run_store.get(run_id)
assert row["status"] == "success"
assert row["total_tokens"] == 15
# Verify messages from DB
messages = await event_store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
# Verify events from DB
events = await event_store.list_events("t1", run_id)
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
await close_engine()