deer-flow/backend/tests/test_run_journal.py
JeffJiang 565ab432fc Fix test assertions for run ordering in RunManager tests
- Updated assertions in `test_list_by_thread` to reflect correct ordering of runs.
- Modified `test_list_by_thread_is_stable_when_timestamps_tie` to ensure stable ordering when timestamps are tied.
2026-04-19 09:55:34 +08:00

386 lines
15 KiB
Python

"""Tests for RunJournal callback handler.
Uses MemoryRunEventStore as the backend for direct event inspection.
"""
import asyncio
from unittest.mock import MagicMock, patch
from uuid import uuid4

import pytest

from deerflow.runtime.events.store.memory import MemoryRunEventStore
from deerflow.runtime.journal import RunJournal
@pytest.fixture
def journal_setup():
    """Provide a (RunJournal, MemoryRunEventStore) pair wired together.

    The high flush threshold keeps events buffered, so each test controls
    persistence explicitly via ``await journal.flush()``.
    """
    backing_store = MemoryRunEventStore()
    journal = RunJournal("r1", "t1", backing_store, flush_threshold=100)
    return journal, backing_store
def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
"""Create a mock LLM response with a message.
model_dump() returns checkpoint-aligned format matching real AIMessage.
"""
msg = MagicMock()
msg.type = "ai"
msg.content = content
msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or []
msg.invalid_tool_calls = []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
msg.additional_kwargs = additional_kwargs or {}
msg.name = None
# model_dump returns checkpoint-aligned format
msg.model_dump.return_value = {
"content": content,
"additional_kwargs": additional_kwargs or {},
"response_metadata": {"model_name": "test-model"},
"type": "ai",
"name": None,
"id": msg.id,
"tool_calls": tool_calls or [],
"invalid_tool_calls": [],
"usage_metadata": usage,
}
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
return response
class TestLlmCallbacks:
    """LLM lifecycle callbacks: trace events, message persistence, token counts."""

    @pytest.mark.anyio
    async def test_on_llm_end_produces_trace_event(self, journal_setup):
        journal, event_store = journal_setup
        rid = uuid4()
        journal.on_llm_start({}, [], run_id=rid, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("Hi"), run_id=rid, parent_run_id=None, tags=["lead_agent"])
        await journal.flush()
        stored = await event_store.list_events("t1", "r1")
        traces = [evt for evt in stored if evt["event_type"] == "llm.ai.response"]
        assert len(traces) == 1
        assert traces[0]["category"] == "message"

    @pytest.mark.anyio
    async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
        journal, event_store = journal_setup
        rid = uuid4()
        journal.on_llm_start({}, [], run_id=rid, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("Answer"), run_id=rid, parent_run_id=None, tags=["lead_agent"])
        await journal.flush()
        stored = await event_store.list_messages("t1")
        assert len(stored) == 1
        only = stored[0]
        assert only["event_type"] == "llm.ai.response"
        # Content is the checkpoint-aligned model_dump payload.
        assert only["content"]["type"] == "ai"
        assert only["content"]["content"] == "Answer"

    @pytest.mark.anyio
    async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
        """LLM response with pending tool_calls emits llm.ai.response with tool_calls in content."""
        journal, event_store = journal_setup
        response = _make_llm_response(
            "Let me search",
            tool_calls=[{"id": "call_1", "name": "search", "args": {}}],
        )
        journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        await journal.flush()
        stored = await event_store.list_messages("t1")
        assert len(stored) == 1
        assert stored[0]["event_type"] == "llm.ai.response"
        assert len(stored[0]["content"]["tool_calls"]) == 1

    @pytest.mark.anyio
    async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
        journal, event_store = journal_setup
        rid = uuid4()
        journal.on_llm_start({}, [], run_id=rid, tags=["subagent:research"])
        journal.on_llm_end(_make_llm_response("Sub answer"), run_id=rid, parent_run_id=None, tags=["subagent:research"])
        await journal.flush()
        # subagent responses still emit llm.ai.response with category="message"
        assert len(await event_store.list_messages("t1")) == 1

    @pytest.mark.anyio
    async def test_token_accumulation(self, journal_setup):
        journal, _store = journal_setup
        samples = [
            ("A", {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
            ("B", {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
        ]
        for text, usage in samples:
            journal.on_llm_end(
                _make_llm_response(text, usage=usage),
                run_id=uuid4(),
                parent_run_id=None,
                tags=["lead_agent"],
            )
        assert journal._total_input_tokens == 30
        assert journal._total_output_tokens == 15
        assert journal._total_tokens == 45
        assert journal._llm_call_count == 2

    @pytest.mark.anyio
    async def test_total_tokens_computed_from_input_output(self, journal_setup):
        """If total_tokens is 0, it should be computed from input + output."""
        journal, _store = journal_setup
        usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}
        journal.on_llm_end(
            _make_llm_response("Hi", usage=usage),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )
        assert journal._total_tokens == 150

    @pytest.mark.anyio
    async def test_caller_token_classification(self, journal_setup):
        journal, _store = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        callers = (("A", "lead_agent"), ("B", "subagent:research"), ("C", "middleware:summarization"))
        for text, caller_tag in callers:
            journal.on_llm_end(
                _make_llm_response(text, usage=usage),
                run_id=uuid4(),
                parent_run_id=None,
                tags=[caller_tag],
            )
        # token tracking not broken by caller type
        assert journal._total_tokens == 45
        assert journal._llm_call_count == 3

    @pytest.mark.anyio
    async def test_usage_metadata_none_no_crash(self, journal_setup):
        journal, _store = journal_setup
        journal.on_llm_end(
            _make_llm_response("No usage", usage=None),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )
        await journal.flush()

    @pytest.mark.anyio
    async def test_latency_tracking(self, journal_setup):
        journal, event_store = journal_setup
        rid = uuid4()
        journal.on_llm_start({}, [], run_id=rid, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("Fast"), run_id=rid, parent_run_id=None, tags=["lead_agent"])
        await journal.flush()
        stored = await event_store.list_events("t1", "r1")
        llm_resp = next(evt for evt in stored if evt["event_type"] == "llm.ai.response")
        assert "latency_ms" in llm_resp["metadata"]
        assert llm_resp["metadata"]["latency_ms"] is not None
class TestLifecycleCallbacks:
    """Chain start/end callbacks map to run.start / run.end trace events."""

    @pytest.mark.anyio
    async def test_chain_start_end_produce_trace_events(self, journal_setup):
        journal, event_store = journal_setup
        journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
        journal.on_chain_end({}, run_id=uuid4())
        await asyncio.sleep(0.05)
        await journal.flush()
        stored = await event_store.list_events("t1", "r1")
        seen_types = {evt["event_type"] for evt in stored}
        assert "run.start" in seen_types
        assert "run.end" in seen_types

    @pytest.mark.anyio
    async def test_nested_chain_no_run_start(self, journal_setup):
        """Nested chains (parent_run_id set) should NOT produce run.start."""
        journal, event_store = journal_setup
        journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=uuid4())
        journal.on_chain_end({}, run_id=uuid4())
        await journal.flush()
        stored = await event_store.list_events("t1", "r1")
        assert all(evt["event_type"] != "run.start" for evt in stored)
class TestToolCallbacks:
    """Tool end/error callbacks: ToolMessage persistence and Command unwrapping."""

    @pytest.mark.anyio
    async def test_tool_end_with_tool_message(self, journal_setup):
        """on_tool_end with a ToolMessage stores it as llm.tool.result."""
        from langchain_core.messages import ToolMessage
        journal, event_store = journal_setup
        journal.on_tool_end(
            ToolMessage(content="results", tool_call_id="call_1", name="web_search"),
            run_id=uuid4(),
        )
        await journal.flush()
        stored = await event_store.list_messages("t1")
        assert len(stored) == 1
        assert stored[0]["event_type"] == "llm.tool.result"
        assert stored[0]["content"]["type"] == "tool"

    @pytest.mark.anyio
    async def test_tool_end_with_command_unwraps_tool_message(self, journal_setup):
        """on_tool_end with Command(update={'messages':[ToolMessage]}) unwraps inner message."""
        from langchain_core.messages import ToolMessage
        from langgraph.types import Command
        journal, event_store = journal_setup
        inner_msg = ToolMessage(content="file list", tool_call_id="call_2", name="present_files")
        journal.on_tool_end(Command(update={"messages": [inner_msg]}), run_id=uuid4())
        await journal.flush()
        stored = await event_store.list_messages("t1")
        assert len(stored) == 1
        assert stored[0]["event_type"] == "llm.tool.result"
        assert stored[0]["content"]["content"] == "file list"

    @pytest.mark.anyio
    async def test_on_tool_error_no_crash(self, journal_setup):
        """on_tool_error should not crash (no event emitted by default)."""
        journal, event_store = journal_setup
        journal.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
        await journal.flush()
        # Base implementation does not emit tool_error — just verify no crash
        assert isinstance(await event_store.list_events("t1", "r1"), list)
class TestCustomEvents:
    """Custom-event dispatch is a no-op on the base callback handler."""

    @pytest.mark.anyio
    async def test_on_custom_event_not_implemented(self, journal_setup):
        """RunJournal does not implement on_custom_event — no crash expected."""
        journal, event_store = journal_setup
        # BaseCallbackHandler.on_custom_event is a no-op by default
        journal.on_custom_event("task_running", {"task_id": "t1"}, run_id=uuid4())
        await journal.flush()
        assert isinstance(await event_store.list_events("t1", "r1"), list)
class TestBufferFlush:
    """Buffer behavior: threshold-triggered async flush and sync-context buffering."""

    @pytest.mark.anyio
    async def test_flush_threshold(self, journal_setup):
        j, store = journal_setup
        j._flush_threshold = 2
        # Each on_llm_end emits 1 event
        j.on_llm_end(_make_llm_response("A"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        assert len(j._buffer) == 1
        j.on_llm_end(_make_llm_response("B"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        # At threshold the buffer should have been flushed asynchronously
        await asyncio.sleep(0.1)
        events = await store.list_events("t1", "r1")
        assert len(events) >= 2

    @pytest.mark.anyio
    async def test_events_retained_when_no_loop(self, journal_setup):
        """Events buffered in a sync (no-loop) context should survive
        until the async flush() in the finally block."""
        j, store = journal_setup
        j._flush_threshold = 1
        # Simulate a sync context by making asyncio.get_running_loop raise,
        # exactly as it would outside any running event loop. mock.patch
        # guarantees restoration even if _put raises, unlike the manual
        # attribute-swap + try/finally it replaces.
        with patch(
            "asyncio.get_running_loop",
            side_effect=RuntimeError("no running event loop"),
        ):
            j._put(event_type="llm.ai.response", category="message", content="test")
        # Without a loop the event must stay buffered, not be dropped.
        assert len(j._buffer) == 1
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "llm.ai.response" for e in events)
class TestIdentifyCaller:
    """_identify_caller maps callback tags to a caller label."""

    def test_lead_agent_tag(self, journal_setup):
        journal, _store = journal_setup
        assert journal._identify_caller(["lead_agent"]) == "lead_agent"

    def test_subagent_tag(self, journal_setup):
        journal, _store = journal_setup
        assert journal._identify_caller(["subagent:research"]) == "subagent:research"

    def test_middleware_tag(self, journal_setup):
        journal, _store = journal_setup
        assert journal._identify_caller(["middleware:summarization"]) == "middleware:summarization"

    def test_no_tags_returns_lead_agent(self, journal_setup):
        journal, _store = journal_setup
        # Both an empty list and a missing tag list default to the lead agent.
        for tags in ([], None):
            assert journal._identify_caller(tags) == "lead_agent"
class TestChainErrorCallback:
    """on_chain_error records a run.error event carrying error details."""

    @pytest.mark.anyio
    async def test_on_chain_error_writes_run_error(self, journal_setup):
        journal, event_store = journal_setup
        journal.on_chain_error(ValueError("boom"), run_id=uuid4())
        await asyncio.sleep(0.05)
        await journal.flush()
        stored = await event_store.list_events("t1", "r1")
        errors = [evt for evt in stored if evt["event_type"] == "run.error"]
        assert len(errors) == 1
        assert "boom" in errors[0]["content"]
        assert errors[0]["metadata"]["error_type"] == "ValueError"
class TestTokenTrackingDisabled:
    """track_token_usage=False disables token and call-count accounting."""

    @pytest.mark.anyio
    async def test_track_token_usage_false(self):
        journal = RunJournal(
            "r1",
            "t1",
            MemoryRunEventStore(),
            track_token_usage=False,
            flush_threshold=100,
        )
        usage = {"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}
        journal.on_llm_end(
            _make_llm_response("X", usage=usage),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )
        summary = journal.get_completion_data()
        assert summary["total_tokens"] == 0
        assert summary["llm_call_count"] == 0
class TestConvenienceFields:
    """Convenience fields surfaced through get_completion_data()."""

    @pytest.mark.anyio
    async def test_first_human_message_via_set(self, journal_setup):
        journal, _store = journal_setup
        journal.set_first_human_message("What is AI?")
        assert journal.get_completion_data()["first_human_message"] == "What is AI?"

    @pytest.mark.anyio
    async def test_get_completion_data(self, journal_setup):
        journal, _store = journal_setup
        journal._total_tokens = 100
        journal._msg_count = 5
        summary = journal.get_completion_data()
        assert summary["total_tokens"] == 100
        assert summary["message_count"] == 5
class TestMiddlewareEvents:
    """record_middleware emits middleware-category events keyed by tag."""

    @pytest.mark.anyio
    async def test_record_middleware_uses_middleware_category(self, journal_setup):
        journal, event_store = journal_setup
        journal.record_middleware(
            "title",
            name="TitleMiddleware",
            hook="after_model",
            action="generate_title",
            changes={"title": "Test Title", "thread_id": "t1"},
        )
        await journal.flush()
        stored = await event_store.list_events("t1", "r1")
        mw = [evt for evt in stored if evt["event_type"] == "middleware:title"]
        assert len(mw) == 1
        assert mw[0]["category"] == "middleware"
        body = mw[0]["content"]
        assert body["name"] == "TitleMiddleware"
        assert body["hook"] == "after_model"
        assert body["action"] == "generate_title"
        assert body["changes"]["title"] == "Test Title"

    @pytest.mark.anyio
    async def test_middleware_tag_variants(self, journal_setup):
        """Different middleware tags produce distinct event_types."""
        journal, event_store = journal_setup
        journal.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
        journal.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
        await journal.flush()
        stored = await event_store.list_events("t1", "r1")
        seen = {evt["event_type"] for evt in stored}
        assert "middleware:title" in seen
        assert "middleware:guardrail" in seen