Mirror of https://github.com/bytedance/deer-flow.git, synced 2026-05-03 07:18:25 +00:00
refactor(journal): fix flush, token tracking, and consolidate tests
RunJournal fixes:
- _flush_sync: retain events in buffer when no event loop instead of dropping them; worker's finally block flushes via async flush().
- on_llm_end: add tool_calls filter and caller=="lead_agent" guard for ai_message events; mark message IDs for dedup with record_llm_usage.
- worker.py: persist completion data (tokens, message count) to RunStore in finally block.

Model factory:
- Auto-inject stream_usage=True for BaseChatOpenAI subclasses with custom api_base, so usage_metadata is populated in streaming responses.

Test consolidation:
- Delete test_phase2b_integration.py (redundant with existing tests).
- Move DB-backed lifecycle test into test_run_journal.py.
- Add tests for stream_usage injection in test_model_factory.py.
- Clean up executor/task_tool dead journal references.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent e5b01d7e74
commit b92ddafd4b
@@ -77,6 +77,15 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
    elif "reasoning_effort" not in model_settings_from_config:
        model_settings_from_config["reasoning_effort"] = "medium"

    # Ensure stream_usage is enabled so that token usage metadata is available
    # in streaming responses. LangChain's BaseChatOpenAI only defaults
    # stream_usage=True when no custom base_url/api_base is set, so models
    # hitting third-party endpoints (e.g. doubao, deepseek) silently lose
    # usage data. We default it to True unless explicitly configured.
    if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
        if "stream_usage" in getattr(model_class, "model_fields", {}):
            model_settings_from_config["stream_usage"] = True

    model_instance = model_class(**kwargs, **model_settings_from_config)

    if is_tracing_enabled():
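A standalone restatement of the condition added above, for readers skimming the hunk (the helper name and signature are illustrative; in the actual change the check is inlined in create_chat_model as shown):

    def maybe_default_stream_usage(model_class, model_settings: dict, kwargs: dict) -> None:
        # Only act when neither the model config nor the caller set stream_usage,
        # and only for classes that actually declare the field (BaseChatOpenAI subclasses).
        already_set = "stream_usage" in model_settings or "stream_usage" in kwargs
        declares_field = "stream_usage" in getattr(model_class, "model_fields", {})
        if not already_set and declares_field:
            model_settings["stream_usage"] = True

With that default in place, OpenAI-compatible models pointed at a custom api_base (doubao, deepseek, and similar endpoints) report usage_metadata on streamed responses instead of silently omitting it.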
@@ -16,7 +16,6 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
@@ -39,7 +38,6 @@ class RunJournal(BaseCallbackHandler):
        event_store: RunEventStore,
        *,
        track_token_usage: bool = True,
        on_complete: Callable[..., Any] | None = None,
        flush_threshold: int = 20,
    ):
        super().__init__()
@@ -47,7 +45,6 @@ class RunJournal(BaseCallbackHandler):
        self.thread_id = thread_id
        self._store = event_store
        self._track_tokens = track_token_usage
        self._on_complete = on_complete
        self._flush_threshold = flush_threshold

        # Write buffer
@@ -73,7 +70,6 @@ class RunJournal(BaseCallbackHandler):
    # -- Lifecycle callbacks --

    def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
        # Only record for the top-level chain (parent_run_id is None)
        if kwargs.get("parent_run_id") is not None:
            return
        self._put(
@@ -87,19 +83,6 @@ class RunJournal(BaseCallbackHandler):
            return
        self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
        self._flush_sync()
        if self._on_complete:
            self._on_complete(
                total_input_tokens=self._total_input_tokens,
                total_output_tokens=self._total_output_tokens,
                total_tokens=self._total_tokens,
                llm_call_count=self._llm_call_count,
                lead_agent_tokens=self._lead_agent_tokens,
                subagent_tokens=self._subagent_tokens,
                middleware_tokens=self._middleware_tokens,
                message_count=self._msg_count,
                last_ai_message=self._last_ai_msg,
                first_human_message=self._first_human_msg,
            )

    def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        if kwargs.get("parent_run_id") is not None:
@@ -131,7 +114,6 @@ class RunJournal(BaseCallbackHandler):
            logger.debug("on_llm_end: could not extract message from response")
            return

        serialized_msg = serialize_lc_object(message)
        caller = self._identify_caller(kwargs)

        # Latency
@@ -142,54 +124,52 @@ class RunJournal(BaseCallbackHandler):
        usage = getattr(message, "usage_metadata", None)
        usage_dict = dict(usage) if usage else {}

        # trace event: llm_end (every LLM call)
        # Trace event: llm_end (every LLM call)
        content = getattr(message, "content", "")
        self._put(
            event_type="llm_end",
            category="trace",
            content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")),
            content=content if isinstance(content, str) else str(content),
            metadata={
                "message": serialized_msg,
                "message": serialize_lc_object(message),
                "caller": caller,
                "usage": usage_dict,
                "latency_ms": latency_ms,
            },
        )

        # message event: ai_message (only lead_agent final replies with content)
        if caller == "lead_agent":
            content = getattr(message, "content", "")
            if isinstance(content, str) and content:
                tool_calls = getattr(message, "tool_calls", None) or []
                tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
                resp_meta = getattr(message, "response_metadata", None) or {}
                model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
                self._put(
                    event_type="ai_message",
                    category="message",
                    content=content,
                    metadata={
                        "model_name": model_name,
                        "tool_calls": tool_calls_summary,
                    },
                )
                self._last_ai_msg = content[:2000]
                self._msg_count += 1
        # Message event: ai_message (only lead_agent final replies — no pending tool_calls)
        tool_calls = getattr(message, "tool_calls", None) or []
        if caller == "lead_agent" and isinstance(content, str) and content and not tool_calls:
            resp_meta = getattr(message, "response_metadata", None) or {}
            model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
            self._put(
                event_type="ai_message",
                category="message",
                content=content,
                metadata={"model_name": model_name},
            )
            self._last_ai_msg = content[:2000]
            self._msg_count += 1

        # Token accumulation
        input_tk = usage_dict.get("input_tokens", 0) or 0
        output_tk = usage_dict.get("output_tokens", 0) or 0
        total_tk = usage_dict.get("total_tokens", 0) or 0
        if self._track_tokens and total_tk > 0:
            self._total_input_tokens += input_tk
            self._total_output_tokens += output_tk
            self._total_tokens += total_tk
            self._llm_call_count += 1
            if caller.startswith("subagent:"):
                self._subagent_tokens += total_tk
            elif caller.startswith("middleware:"):
                self._middleware_tokens += total_tk
            else:
                self._lead_agent_tokens += total_tk
        if self._track_tokens:
            input_tk = usage_dict.get("input_tokens", 0) or 0
            output_tk = usage_dict.get("output_tokens", 0) or 0
            total_tk = usage_dict.get("total_tokens", 0) or 0
            if total_tk == 0:
                total_tk = input_tk + output_tk
            if total_tk > 0:
                self._total_input_tokens += input_tk
                self._total_output_tokens += output_tk
                self._total_tokens += total_tk
                self._llm_call_count += 1
                if caller.startswith("subagent:"):
                    self._subagent_tokens += total_tk
                elif caller.startswith("middleware:"):
                    self._middleware_tokens += total_tk
                else:
                    self._lead_agent_tokens += total_tk

    def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
        self._llm_start_times.pop(str(run_id), None)
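One detail worth calling out in the new accumulation block above: when a provider reports input_tokens/output_tokens but leaves total_tokens at 0, the total is now derived instead of the whole call being skipped. A minimal, self-contained sketch of that rule (helper name is illustrative, not part of the change):

    def effective_total_tokens(usage: dict) -> int:
        # Mirror the fallback in on_llm_end: derive the total when the provider
        # does not populate total_tokens.
        input_tk = usage.get("input_tokens", 0) or 0
        output_tk = usage.get("output_tokens", 0) or 0
        total_tk = usage.get("total_tokens", 0) or 0
        return total_tk if total_tk > 0 else input_tk + output_tk

    assert effective_total_tokens({"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}) == 150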
@@ -277,20 +257,23 @@ class RunJournal(BaseCallbackHandler):
        self._flush_sync()

    def _flush_sync(self) -> None:
        """Flush buffer to RunEventStore.
        """Best-effort flush of buffer to RunEventStore.

        BaseCallbackHandler methods are synchronous. We schedule the async
        put_batch via the current event loop.
        BaseCallbackHandler methods are synchronous. If an event loop is
        running we schedule an async ``put_batch``; otherwise the events
        stay in the buffer and are flushed later by the async ``flush()``
        call in the worker's ``finally`` block.
        """
        if not self._buffer:
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        try:
            loop = asyncio.get_running_loop()
            loop.create_task(self._flush_async(batch))
        except RuntimeError:
            logger.warning("RunJournal: no event loop, dropping %d events", len(batch))
            # No event loop — keep events in buffer for later async flush.
            return
        batch = self._buffer.copy()
        self._buffer.clear()
        loop.create_task(self._flush_async(batch))

    async def _flush_async(self, batch: list[dict]) -> None:
        try:
@@ -302,7 +285,10 @@ class RunJournal(BaseCallbackHandler):
        for tag in kwargs.get("tags") or []:
            if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
                return tag
        return "unknown"
        # Default to lead_agent: the main agent graph does not inject
        # callback tags, while subagents and middleware explicitly tag
        # themselves.
        return "lead_agent"

    # -- Public methods (called by worker) --

@@ -311,7 +297,7 @@ class RunJournal(BaseCallbackHandler):
        self._first_human_msg = content[:2000] if content else None

    async def flush(self) -> None:
        """Force flush. Used in cancel/error paths."""
        """Force flush remaining buffer. Called in worker's finally block."""
        if self._buffer:
            batch = self._buffer.copy()
            self._buffer.clear()

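The buffering contract described in the new docstrings above, condensed into an illustrative class. This is a simplified sketch, not the real implementation: it skips the flush threshold and error handling, and assumes the event store exposes an async put_batch as the docstrings mention.

    import asyncio


    class JournalBufferSketch:
        def __init__(self, event_store):
            self._buffer: list[dict] = []
            self._store = event_store

        def _flush_sync(self) -> None:
            # Synchronous callback path: schedule an async write when a loop is
            # running; otherwise keep events buffered instead of dropping them.
            if not self._buffer:
                return
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                return  # no running loop: leave events for the async flush() below
            batch = self._buffer.copy()
            self._buffer.clear()
            loop.create_task(self._store.put_batch(batch))

        async def flush(self) -> None:
            # The worker's finally block calls this to drain whatever
            # _flush_sync could not schedule.
            if self._buffer:
                batch = self._buffer.copy()
                self._buffer.clear()
                await self._store.put_batch(batch)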
@@ -123,7 +123,8 @@ async def run_agent(
    runtime = Runtime(context={"thread_id": thread_id}, store=store)
    config.setdefault("configurable", {})["__pregel_runtime"] = runtime

    # Inject RunJournal as a callback
    # Inject RunJournal as a LangChain callback handler.
    # on_llm_end captures token usage; on_chain_start/end captures lifecycle.
    if journal is not None:
        config.setdefault("callbacks", []).append(journal)

@@ -241,13 +242,25 @@ async def run_agent(
        )

    finally:
        # Flush any buffered journal events
        # Flush any buffered journal events and persist completion data
        if journal is not None:
            try:
                await journal.flush()
            except Exception:
                logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)

            # Persist token usage + convenience fields to RunStore
            if run_manager._store is not None:
                try:
                    completion = journal.get_completion_data()
                    await run_manager._store.update_run_completion(
                        run_id,
                        status=record.status.value,
                        **completion,
                    )
                except Exception:
                    logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)

        await bridge.publish_end(run_id)
        asyncio.create_task(bridge.cleanup(run_id, delay=60))

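The persistence half of that finally block, as a hypothetical free-standing helper (the commit inlines this in run_agent; get_completion_data and update_run_completion are the real methods shown above, while the wrapper itself is only illustrative):

    async def persist_run_completion(run_store, journal, run_id: str, status: str) -> None:
        # Push the journal's aggregated totals (tokens, message count, first/last
        # messages) into the RunStore together with the run's final status.
        completion = journal.get_completion_data()
        await run_store.update_run_completion(run_id, status=status, **completion)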
@@ -593,6 +593,84 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
    assert "max_tokens" not in FakeChatModel.captured_kwargs


# ---------------------------------------------------------------------------
# stream_usage injection
# ---------------------------------------------------------------------------


class _FakeWithStreamUsage(FakeChatModel):
    """Fake model that declares stream_usage in model_fields (like BaseChatOpenAI)."""

    stream_usage: bool | None = None


def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
    """Factory should set stream_usage=True for models with stream_usage field."""
    cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
    _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)

    captured: dict = {}

    class CapturingModel(_FakeWithStreamUsage):
        def __init__(self, **kwargs):
            captured.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)

    factory_module.create_chat_model(name="deepseek")

    assert captured.get("stream_usage") is True


def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
    """Factory should NOT inject stream_usage for models without the field."""
    cfg = _make_app_config([_make_model("claude", use="langchain_anthropic:ChatAnthropic")])
    _patch_factory(monkeypatch, cfg)

    captured: dict = {}

    class CapturingModel(FakeChatModel):
        def __init__(self, **kwargs):
            captured.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)

    factory_module.create_chat_model(name="claude")

    assert "stream_usage" not in captured


def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
    """If config dumps stream_usage=False, factory should respect it."""
    cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
    _patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)

    captured: dict = {}

    class CapturingModel(_FakeWithStreamUsage):
        def __init__(self, **kwargs):
            captured.update(kwargs)
            BaseChatModel.__init__(self, **kwargs)

    monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)

    # Simulate config having stream_usage explicitly set by patching model_dump
    original_get_model_config = cfg.get_model_config

    def patched_get_model_config(name):
        mc = original_get_model_config(name)
        mc.stream_usage = False  # type: ignore[attr-defined]
        return mc

    monkeypatch.setattr(cfg, "get_model_config", patched_get_model_config)

    factory_module.create_chat_model(name="deepseek")

    assert captured.get("stream_usage") is False


def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
    model = ModelConfig(
        name="gpt-5-responses",
@@ -15,7 +15,6 @@ import pytest
from deerflow.config.database_config import DatabaseConfig
from deerflow.runtime.runs.store.memory import MemoryRunStore


# -- DatabaseConfig --


@@ -1,279 +0,0 @@
"""Phase 2-B integration tests.

End-to-end test: simulate a run's complete lifecycle, verify data
is correctly written to both RunStore and RunEventStore.
"""

import asyncio
from uuid import uuid4

import pytest

from deerflow.runtime.events.store.memory import MemoryRunEventStore
from deerflow.runtime.journal import RunJournal
from deerflow.runtime.runs.store.memory import MemoryRunStore


class _FakeMessage:
    def __init__(self, content, usage):
        self.content = content
        self.tool_calls = []
        self.response_metadata = {"model_name": "test-model"}
        self.usage_metadata = usage
        self.id = "test-msg-id"

    def model_dump(self):
        return {"type": "ai", "content": self.content, "id": self.id, "tool_calls": [], "usage_metadata": self.usage_metadata, "response_metadata": self.response_metadata}


class _FakeGeneration:
    def __init__(self, message):
        self.message = message


class _FakeLLMResult:
    def __init__(self, content, usage):
        self.generations = [[_FakeGeneration(_FakeMessage(content, usage))]]


def _make_llm_response(content="Hello", usage=None):
    return _FakeLLMResult(content, usage)


class TestRunLifecycle:
    @pytest.mark.anyio
    async def test_full_run_lifecycle(self):
        """Simulate a complete run lifecycle with RunStore + RunEventStore."""
        run_store = MemoryRunStore()
        event_store = MemoryRunEventStore()

        # 1. Create run
        await run_store.put("r1", thread_id="t1", status="pending")

        # 2. Write human_message
        await event_store.put(
            thread_id="t1",
            run_id="r1",
            event_type="human_message",
            category="message",
            content="What is AI?",
        )

        # 3. Simulate RunJournal callback sequence
        on_complete_data = {}

        def on_complete(**data):
            on_complete_data.update(data)

        journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100)
        journal.set_first_human_message("What is AI?")

        # chain_start (top-level)
        journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None)

        # llm_start + llm_end
        llm_run_id = uuid4()
        journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"])
        usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}
        journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"])

        # chain_end (triggers on_complete + flush_sync which creates a task)
        journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        await journal.flush()
        # Let event loop process any pending flush tasks from _flush_sync
        await asyncio.sleep(0.05)

        # 4. Verify messages
        messages = await event_store.list_messages("t1")
        assert len(messages) == 2  # human + ai
        assert messages[0]["event_type"] == "human_message"
        assert messages[1]["event_type"] == "ai_message"
        assert messages[1]["content"] == "AI is artificial intelligence."

        # 5. Verify events
        events = await event_store.list_events("t1", "r1")
        event_types = {e["event_type"] for e in events}
        assert "run_start" in event_types
        assert "llm_start" in event_types
        assert "llm_end" in event_types
        assert "run_end" in event_types

        # 6. Verify on_complete data
        assert on_complete_data["total_tokens"] == 150
        assert on_complete_data["llm_call_count"] == 1
        assert on_complete_data["lead_agent_tokens"] == 150
        assert on_complete_data["message_count"] == 1
        assert on_complete_data["last_ai_message"] == "AI is artificial intelligence."
        assert on_complete_data["first_human_message"] == "What is AI?"

    @pytest.mark.anyio
    async def test_run_with_tool_calls(self):
        """Simulate a run that uses tools."""
        event_store = MemoryRunEventStore()
        journal = RunJournal("r1", "t1", event_store, flush_threshold=100)

        # tool_start + tool_end
        journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4())
        journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search")
        await journal.flush()

        events = await event_store.list_events("t1", "r1")
        assert len(events) == 2
        assert events[0]["event_type"] == "tool_start"
        assert events[1]["event_type"] == "tool_end"

    @pytest.mark.anyio
    async def test_multi_run_thread(self):
        """Multiple runs on the same thread maintain unified seq ordering."""
        event_store = MemoryRunEventStore()

        # Run 1
        await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
        await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")

        # Run 2
        await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
        await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")

        messages = await event_store.list_messages("t1")
        assert len(messages) == 4
        assert [m["seq"] for m in messages] == [1, 2, 3, 4]
        assert messages[0]["run_id"] == "r1"
        assert messages[2]["run_id"] == "r2"

    @pytest.mark.anyio
    async def test_runmanager_with_store_backing(self):
        """RunManager persists to RunStore when one is provided."""
        from deerflow.runtime.runs.manager import RunManager

        run_store = MemoryRunStore()
        mgr = RunManager(store=run_store)

        record = await mgr.create("t1", assistant_id="lead_agent")
        # Verify persisted to store
        row = await run_store.get(record.run_id)
        assert row is not None
        assert row["thread_id"] == "t1"
        assert row["status"] == "pending"

        # Status update
        from deerflow.runtime.runs.schemas import RunStatus

        await mgr.set_status(record.run_id, RunStatus.running)
        row = await run_store.get(record.run_id)
        assert row["status"] == "running"

    @pytest.mark.anyio
    async def test_runmanager_create_or_reject_persists(self):
        """create_or_reject also persists to store."""
        from deerflow.runtime.runs.manager import RunManager

        run_store = MemoryRunStore()
        mgr = RunManager(store=run_store)

        record = await mgr.create_or_reject("t1", "lead_agent", metadata={"key": "val"})
        row = await run_store.get(record.run_id)
        assert row is not None
        assert row["status"] == "pending"
        assert row["metadata"] == {"key": "val"}

    @pytest.mark.anyio
    async def test_follow_up_metadata_in_messages(self):
        """human_message metadata carries follow_up_to_run_id."""
        event_store = MemoryRunEventStore()

        # Run 1
        await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
        await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")

        # Run 2 (follow-up)
        await event_store.put(
            thread_id="t1",
            run_id="r2",
            event_type="human_message",
            category="message",
            content="Tell me more",
            metadata={"follow_up_to_run_id": "r1"},
        )

        messages = await event_store.list_messages("t1")
        assert len(messages) == 3
        assert messages[2]["metadata"]["follow_up_to_run_id"] == "r1"

    @pytest.mark.anyio
    async def test_summarization_in_history(self):
        """summary message appears correctly in message history."""
        event_store = MemoryRunEventStore()

        await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
        await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
        await event_store.put(thread_id="t1", run_id="r2", event_type="summary", category="message", content="Previous conversation summarized.", metadata={"replaced_count": 2})
        await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
        await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")

        messages = await event_store.list_messages("t1")
        assert len(messages) == 5
        assert messages[2]["event_type"] == "summary"
        assert messages[2]["metadata"]["replaced_count"] == 2

    @pytest.mark.anyio
    async def test_db_backed_run_lifecycle(self, tmp_path):
        """Full lifecycle with SQLite-backed RunRepository + DbRunEventStore."""
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
        from deerflow.persistence.repositories.run_repo import RunRepository
        from deerflow.runtime.events.store.db import DbRunEventStore
        from deerflow.runtime.runs.manager import RunManager

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        sf = get_session_factory()

        run_store = RunRepository(sf)
        event_store = DbRunEventStore(sf)
        mgr = RunManager(store=run_store)

        # Create run
        record = await mgr.create("t1", "lead_agent")
        run_id = record.run_id

        # Write human_message
        await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB")

        # Simulate journal
        on_complete_data = {}
        journal = RunJournal(run_id, "t1", event_store, on_complete=lambda **d: on_complete_data.update(d), flush_threshold=100)
        journal.set_first_human_message("Hello DB")

        journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
        llm_rid = uuid4()
        journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=llm_rid, tags=["lead_agent"])
        journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        await journal.flush()
        await asyncio.sleep(0.05)

        # Verify run persisted
        row = await run_store.get(run_id)
        assert row is not None
        assert row["status"] == "pending"  # RunManager set it, journal doesn't update status

        # Update completion
        await run_store.update_run_completion(run_id, status="success", **on_complete_data)
        row = await run_store.get(run_id)
        assert row["status"] == "success"
        assert row["total_tokens"] == 15

        # Verify messages from DB
        messages = await event_store.list_messages("t1")
        assert len(messages) == 2
        assert messages[0]["event_type"] == "human_message"
        assert messages[1]["event_type"] == "ai_message"

        # Verify events from DB
        events = await event_store.list_events("t1", run_id)
        event_types = {e["event_type"] for e in events}
        assert "run_start" in event_types
        assert "llm_end" in event_types
        assert "run_end" in event_types

        await close_engine()
@@ -16,22 +16,28 @@ from deerflow.runtime.journal import RunJournal
@pytest.fixture
def journal_setup():
    store = MemoryRunEventStore()
    on_complete_data = {}

    def on_complete(**data):
        on_complete_data.update(data)

    j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100)
    return j, store, on_complete_data
    j = RunJournal("r1", "t1", store, flush_threshold=100)
    return j, store


def _make_llm_response(content="Hello", usage=None):
def _make_llm_response(content="Hello", usage=None, tool_calls=None):
    """Create a mock LLM response with a message."""
    msg = MagicMock()
    msg.content = content
    msg.tool_calls = []
    msg.id = f"msg-{id(msg)}"
    msg.tool_calls = tool_calls or []
    msg.response_metadata = {"model_name": "test-model"}
    msg.usage_metadata = usage
    # Provide a real model_dump so serialize_lc_object returns a plain dict
    # (needed for DB-backed tests where json.dumps must succeed).
    msg.model_dump.return_value = {
        "type": "ai",
        "content": content,
        "id": msg.id,
        "tool_calls": tool_calls or [],
        "usage_metadata": usage,
        "response_metadata": {"model_name": "test-model"},
    }

    gen = MagicMock()
    gen.message = msg
@@ -44,7 +50,7 @@ def _make_llm_response(content="Hello", usage=None):
class TestLlmCallbacks:
    @pytest.mark.anyio
    async def test_on_llm_end_produces_trace_event(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
@@ -56,7 +62,7 @@ class TestLlmCallbacks:

    @pytest.mark.anyio
    async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"])
@@ -66,9 +72,23 @@ class TestLlmCallbacks:
        assert messages[0]["event_type"] == "ai_message"
        assert messages[0]["content"] == "Answer"

    @pytest.mark.anyio
    async def test_on_llm_end_with_tool_calls_no_ai_message(self, journal_setup):
        """LLM response with pending tool_calls should NOT produce ai_message."""
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_end(
            _make_llm_response("Let me search", tool_calls=[{"name": "search"}]),
            run_id=run_id,
            tags=["lead_agent"],
        )
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 0

    @pytest.mark.anyio
    async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
        j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"])
@@ -78,27 +98,34 @@ class TestLlmCallbacks:

    @pytest.mark.anyio
    async def test_token_accumulation(self, journal_setup):
        j, store, on_complete_data = journal_setup
        j, store = journal_setup
        usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
        j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"])
        j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"])
        assert j._total_input_tokens == 30
        assert j._total_output_tokens == 15
        assert j._total_tokens == 45
        assert j._llm_call_count == 2

    @pytest.mark.anyio
    async def test_total_tokens_computed_from_input_output(self, journal_setup):
        """If total_tokens is 0, it should be computed from input + output."""
        j, store = journal_setup
        j.on_llm_end(
            _make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}),
            run_id=uuid4(),
            tags=["lead_agent"],
        )
        assert j._total_tokens == 150
        assert j._lead_agent_tokens == 150

    @pytest.mark.anyio
    async def test_caller_token_classification(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"])
        j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"])
        j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"])
        j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"])
        j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"])
        assert j._lead_agent_tokens == 15
        assert j._subagent_tokens == 15
@@ -106,15 +133,13 @@ class TestLlmCallbacks:

    @pytest.mark.anyio
    async def test_usage_metadata_none_no_crash(self, journal_setup):
        j, store, _ = journal_setup
        j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
        j, store = journal_setup
        j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"])
        # Should not raise
        await j.flush()

    @pytest.mark.anyio
    async def test_latency_tracking(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
@@ -127,16 +152,20 @@ class TestLlmCallbacks:

class TestLifecycleCallbacks:
    @pytest.mark.anyio
    async def test_on_chain_end_triggers_on_complete(self, journal_setup):
        j, store, on_complete_data = journal_setup
    async def test_chain_start_end_produce_lifecycle_events(self, journal_setup):
        j, store = journal_setup
        j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
        j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        assert "total_tokens" in on_complete_data
        assert "message_count" in on_complete_data
        await asyncio.sleep(0.05)
        await j.flush()
        events = await store.list_events("t1", "r1")
        types = [e["event_type"] for e in events if e["category"] == "lifecycle"]
        assert "run_start" in types
        assert "run_end" in types

    @pytest.mark.anyio
    async def test_nested_chain_ignored(self, journal_setup):
        j, store, on_complete_data = journal_setup
        j, store = journal_setup
        parent_id = uuid4()
        j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
        j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
@@ -149,7 +178,7 @@ class TestLifecycleCallbacks:
class TestToolCallbacks:
    @pytest.mark.anyio
    async def test_tool_start_end_produce_trace(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4())
        j.on_tool_end("results", run_id=uuid4(), name="web_search")
        await j.flush()
@@ -158,11 +187,19 @@ class TestToolCallbacks:
        assert "tool_start" in types
        assert "tool_end" in types

    @pytest.mark.anyio
    async def test_on_tool_error(self, journal_setup):
        j, store = journal_setup
        j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "tool_error" for e in events)


class TestCustomEvents:
    @pytest.mark.anyio
    async def test_summarization_event(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        j.on_custom_event(
            "summarization",
            {"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]},
@@ -176,50 +213,76 @@ class TestCustomEvents:
        assert len(messages) == 1
        assert messages[0]["event_type"] == "summary"

    @pytest.mark.anyio
    async def test_non_summarization_custom_event(self, journal_setup):
        j, store = journal_setup
        j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "task_running" for e in events)


class TestBufferFlush:
    @pytest.mark.anyio
    async def test_flush_threshold(self, journal_setup):
        j, store, _ = journal_setup
        j, store = journal_setup
        j._flush_threshold = 3
        j.on_tool_start({"name": "a"}, "x", run_id=uuid4())
        j.on_tool_start({"name": "b"}, "x", run_id=uuid4())
        # Buffer has 2 events, not yet flushed
        assert len(j._buffer) == 2
        j.on_tool_start({"name": "c"}, "x", run_id=uuid4())
        # Buffer should have been flushed (threshold=3 triggers flush)
        # Give the async task a chance to complete
        await asyncio.sleep(0.1)
        events = await store.list_events("t1", "r1")
        assert len(events) >= 3

    @pytest.mark.anyio
    async def test_events_retained_when_no_loop(self, journal_setup):
        """Events buffered in a sync (no-loop) context should survive
        until the async flush() in the finally block."""
        j, store = journal_setup
        j._flush_threshold = 1

        original = asyncio.get_running_loop

        def no_loop():
            raise RuntimeError("no running event loop")

        asyncio.get_running_loop = no_loop
        try:
            j._put(event_type="llm_end", category="trace", content="test")
        finally:
            asyncio.get_running_loop = original

        assert len(j._buffer) == 1
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "llm_end" for e in events)


class TestIdentifyCaller:
    def test_lead_agent_tag(self, journal_setup):
        j, _, _ = journal_setup
        j, _ = journal_setup
        assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent"

    def test_subagent_tag(self, journal_setup):
        j, _, _ = journal_setup
        j, _ = journal_setup
        assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research"

    def test_middleware_tag(self, journal_setup):
        j, _, _ = journal_setup
        j, _ = journal_setup
        assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization"

    def test_no_tags_returns_unknown(self, journal_setup):
        j, _, _ = journal_setup
        assert j._identify_caller({"tags": []}) == "unknown"
        assert j._identify_caller({}) == "unknown"
    def test_no_tags_returns_lead_agent(self, journal_setup):
        j, _ = journal_setup
        assert j._identify_caller({"tags": []}) == "lead_agent"
        assert j._identify_caller({}) == "lead_agent"


class TestChainErrorCallback:
    @pytest.mark.anyio
    async def test_on_chain_error_writes_run_error(self, journal_setup):
        j, store, _ = journal_setup
        # parent_run_id must be None (top-level chain) for the event to be recorded
        j, store = journal_setup
        j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None)
        # on_chain_error calls _flush_sync internally, give async task time to complete
        await asyncio.sleep(0.05)
        await j.flush()
        events = await store.list_events("t1", "r1")
@@ -232,85 +295,125 @@ class TestChainErrorCallback:
class TestTokenTrackingDisabled:
    @pytest.mark.anyio
    async def test_track_token_usage_false(self):
        """track_token_usage=False disables token accumulation."""
        store = MemoryRunEventStore()
        complete_data = {}
        j = RunJournal("r1", "t1", store, track_token_usage=False, on_complete=lambda **d: complete_data.update(d), flush_threshold=100)
        j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), run_id=uuid4(), tags=["lead_agent"])
        j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        assert complete_data["total_tokens"] == 0
        assert complete_data["llm_call_count"] == 0


class TestMiddlewareNoMessage:
    @pytest.mark.anyio
    async def test_on_llm_end_middleware_no_ai_message(self, journal_setup):
        j, store, _ = journal_setup
        j.on_llm_end(_make_llm_response("Summary"), run_id=uuid4(), tags=["middleware:summarization"])
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 0


class TestUnknownCallerTokens:
    @pytest.mark.anyio
    async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
        """No caller tag: tokens attributed to lead_agent bucket."""
        j, store, _ = journal_setup
        j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=uuid4(), tags=[])
        assert j._lead_agent_tokens == 15
        j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
        j.on_llm_end(
            _make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}),
            run_id=uuid4(),
            tags=["lead_agent"],
        )
        data = j.get_completion_data()
        assert data["total_tokens"] == 0
        assert data["llm_call_count"] == 0


class TestConvenienceFields:
    @pytest.mark.anyio
    async def test_last_ai_message_tracks_latest(self, journal_setup):
        j, store, complete_data = journal_setup
        j, store = journal_setup
        j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"])
        j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        assert complete_data["last_ai_message"] == "Second"
        assert complete_data["message_count"] == 2
        data = j.get_completion_data()
        assert data["last_ai_message"] == "Second"
        assert data["message_count"] == 2

    @pytest.mark.anyio
    async def test_first_human_message_via_set(self, journal_setup):
        j, store, complete_data = journal_setup
        j, _ = journal_setup
        j.set_first_human_message("What is AI?")
        j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        assert complete_data["first_human_message"] == "What is AI?"


class TestToolError:
    @pytest.mark.anyio
    async def test_on_tool_error(self, journal_setup):
        j, store, _ = journal_setup
        j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "tool_error" for e in events)


class TestOtherCustomEvent:
    @pytest.mark.anyio
    async def test_non_summarization_custom_event(self, journal_setup):
        j, store, _ = journal_setup
        j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "task_running" for e in events)


class TestPublicMethods:
    @pytest.mark.anyio
    async def test_set_first_human_message(self, journal_setup):
        j, _, _ = journal_setup
        j.set_first_human_message("Hello world")
        assert j._first_human_msg == "Hello world"
        data = j.get_completion_data()
        assert data["first_human_message"] == "What is AI?"

    @pytest.mark.anyio
    async def test_get_completion_data(self, journal_setup):
        j, _, _ = journal_setup
        j, _ = journal_setup
        j._total_tokens = 100
        j._msg_count = 5
        data = j.get_completion_data()
        assert data["total_tokens"] == 100
        assert data["message_count"] == 5


class TestUnknownCallerTokens:
    @pytest.mark.anyio
    async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
        j, store = journal_setup
        j.on_llm_end(
            _make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
            run_id=uuid4(),
            tags=[],
        )
        assert j._lead_agent_tokens == 15


# ---------------------------------------------------------------------------
# SQLite-backed end-to-end test
# ---------------------------------------------------------------------------


class TestDbBackedLifecycle:
    @pytest.mark.anyio
    async def test_full_lifecycle_with_sqlite(self, tmp_path):
        """Full lifecycle with SQLite-backed RunRepository + DbRunEventStore."""
        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
        from deerflow.persistence.repositories.run_repo import RunRepository
        from deerflow.runtime.events.store.db import DbRunEventStore
        from deerflow.runtime.runs.manager import RunManager

        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
        sf = get_session_factory()

        run_store = RunRepository(sf)
        event_store = DbRunEventStore(sf)
        mgr = RunManager(store=run_store)

        # Create run
        record = await mgr.create("t1", "lead_agent")
        run_id = record.run_id

        # Write human_message
        await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB")

        # Simulate journal
        journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
        journal.set_first_human_message("Hello DB")

        journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
        llm_rid = uuid4()
        journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"])
        journal.on_llm_end(
            _make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
            run_id=llm_rid,
            tags=["lead_agent"],
        )
        journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
        await asyncio.sleep(0.05)
        await journal.flush()

        # Verify run persisted
        row = await run_store.get(run_id)
        assert row is not None
        assert row["status"] == "pending"

        # Update completion
        completion = journal.get_completion_data()
        await run_store.update_run_completion(run_id, status="success", **completion)
        row = await run_store.get(run_id)
        assert row["status"] == "success"
        assert row["total_tokens"] == 15

        # Verify messages from DB
        messages = await event_store.list_messages("t1")
        assert len(messages) == 2
        assert messages[0]["event_type"] == "human_message"
        assert messages[1]["event_type"] == "ai_message"

        # Verify events from DB
        events = await event_store.list_events("t1", run_id)
        event_types = {e["event_type"] for e in events}
        assert "run_start" in event_types
        assert "llm_end" in event_types
        assert "run_end" in event_types

        await close_engine()