* fix: bucket subagent token usage into RunRow.subagent_tokens

  Add caller-bucketed token tracking to RunJournal so subagent and middleware LLM calls are written to the correct RunRow columns instead of all falling into lead_agent_tokens (default 0).

  - RunJournal: accumulate _lead_agent_tokens / _subagent_tokens / _middleware_tokens in on_llm_end, deduped by langchain run_id. Add record_external_llm_usage_records() for external sources (respects track_token_usage flag). Return caller buckets from get_completion_data().
  - SubagentTokenCollector: new lightweight callback handler that collects LLM usage within subagent execution.
  - SubagentExecutor: wire collector into subagent run_config and sync records to SubagentResult on every chunk (timeout/cancel safe).
  - SubagentResult: add token_usage_records and usage_reported fields.
  - task_tool: report subagent usage to parent RunJournal on every terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including the CancelledError path, guarded against double-reporting.

  No DB migration needed; RunRow columns already exist.

* Potential fix for pull request finding

  Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: address token usage review feedback

* Address review follow-ups

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
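A minimal sketch of the bucketing behavior the tests below pin down. It is illustrative only: the attribute and method names (_lead_agent_tokens, _subagent_tokens, _middleware_tokens, _counted_llm_run_ids, _identify_caller, track_token_usage) are the ones asserted by the tests, but the control flow here is an assumption, not the actual RunJournal implementation.

```python
class BucketingSketch:
    """Sketch of tag-based caller bucketing with per-run_id dedup (not the real RunJournal)."""

    def __init__(self, track_token_usage=True):
        self.track_token_usage = track_token_usage
        self._lead_agent_tokens = 0
        self._subagent_tokens = 0
        self._middleware_tokens = 0
        self._total_tokens = 0
        self._counted_llm_run_ids = set()

    def _identify_caller(self, tags):
        # Missing tags, or tags without a recognized prefix, default to lead_agent.
        for tag in tags or []:
            if tag == "lead_agent" or tag.startswith(("subagent:", "middleware:")):
                return tag
        return "lead_agent"

    def count_usage(self, run_id, tags, usage):
        # Respect the track_token_usage flag and ignore callbacks without usage,
        # so a later callback for the same run_id that does carry usage still counts.
        if not self.track_token_usage or not usage:
            return
        if str(run_id) in self._counted_llm_run_ids:
            return  # dedupe repeated callbacks for the same langchain run_id
        self._counted_llm_run_ids.add(str(run_id))
        total = usage.get("total_tokens") or (
            usage.get("input_tokens", 0) + usage.get("output_tokens", 0)
        )
        caller = self._identify_caller(tags)
        if caller.startswith("subagent:"):
            self._subagent_tokens += total
        elif caller.startswith("middleware:"):
            self._middleware_tokens += total
        else:
            self._lead_agent_tokens += total
        self._total_tokens += total
```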
"""Tests for RunJournal callback handler.
|
|
|
|
Uses MemoryRunEventStore as the backend for direct event inspection.
|
|
"""
|
|
|
|
import asyncio
|
|
from unittest.mock import MagicMock
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
|
from deerflow.runtime.journal import RunJournal
|
|
|
|
|
|
@pytest.fixture
def journal_setup():
    store = MemoryRunEventStore()
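    # High flush threshold so events stay buffered until each test calls flush() explicitly.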
    j = RunJournal("r1", "t1", store, flush_threshold=100)
    return j, store


def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
    """Create a mock LLM response with a message.

    model_dump() returns checkpoint-aligned format matching real AIMessage.
    """
    msg = MagicMock()
    msg.type = "ai"
    msg.content = content
    msg.id = f"msg-{id(msg)}"
    msg.tool_calls = tool_calls or []
    msg.invalid_tool_calls = []
    msg.response_metadata = {"model_name": "test-model"}
    msg.usage_metadata = usage
    msg.additional_kwargs = additional_kwargs or {}
    msg.name = None
    # model_dump returns checkpoint-aligned format
    msg.model_dump.return_value = {
        "content": content,
        "additional_kwargs": additional_kwargs or {},
        "response_metadata": {"model_name": "test-model"},
        "type": "ai",
        "name": None,
        "id": msg.id,
        "tool_calls": tool_calls or [],
        "invalid_tool_calls": [],
        "usage_metadata": usage,
    }

    gen = MagicMock()
    gen.message = msg

    response = MagicMock()
    response.generations = [[gen]]
    return response


class TestLlmCallbacks:
    @pytest.mark.anyio
    async def test_on_llm_end_produces_trace_event(self, journal_setup):
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        await j.flush()
        events = await store.list_events("t1", "r1")
        trace_events = [e for e in events if e["event_type"] == "llm.ai.response"]
        assert len(trace_events) == 1
        assert trace_events[0]["category"] == "message"

    @pytest.mark.anyio
    async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.ai.response"
        # Content is checkpoint-aligned model_dump format
        assert messages[0]["content"]["type"] == "ai"
        assert messages[0]["content"]["content"] == "Answer"

    @pytest.mark.anyio
    async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
        """LLM response with pending tool_calls emits llm.ai.response with tool_calls in content."""
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_end(
            _make_llm_response("Let me search", tool_calls=[{"id": "call_1", "name": "search", "args": {}}]),
            run_id=run_id,
            parent_run_id=None,
            tags=["lead_agent"],
        )
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.ai.response"
        assert len(messages[0]["content"]["tool_calls"]) == 1

    @pytest.mark.anyio
    async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
        j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, parent_run_id=None, tags=["subagent:research"])
        await j.flush()
        messages = await store.list_messages("t1")
        # subagent responses still emit llm.ai.response with category="message"
        assert len(messages) == 1

    @pytest.mark.anyio
    async def test_token_accumulation(self, journal_setup):
        j, store = journal_setup
        usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
        j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        assert j._total_input_tokens == 30
        assert j._total_output_tokens == 15
        assert j._total_tokens == 45
        assert j._llm_call_count == 2

    @pytest.mark.anyio
    async def test_total_tokens_computed_from_input_output(self, journal_setup):
        """If total_tokens is 0, it should be computed from input + output."""
        j, store = journal_setup
        j.on_llm_end(
            _make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )
        assert j._total_tokens == 150

    @pytest.mark.anyio
    async def test_caller_token_classification(self, journal_setup):
        j, store = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
        j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarization"])
        # token tracking not broken by caller type
        assert j._total_tokens == 45
        assert j._llm_call_count == 3

    @pytest.mark.anyio
    async def test_usage_metadata_none_no_crash(self, journal_setup):
        j, store = journal_setup
        j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        await j.flush()

    @pytest.mark.anyio
    async def test_latency_tracking(self, journal_setup):
        j, store = journal_setup
        run_id = uuid4()
        j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        await j.flush()
        events = await store.list_events("t1", "r1")
        llm_resp = [e for e in events if e["event_type"] == "llm.ai.response"][0]
        assert "latency_ms" in llm_resp["metadata"]
        assert llm_resp["metadata"]["latency_ms"] is not None


class TestLifecycleCallbacks:
    @pytest.mark.anyio
    async def test_chain_start_end_produce_trace_events(self, journal_setup):
        j, store = journal_setup
        j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
        j.on_chain_end({}, run_id=uuid4())
        await asyncio.sleep(0.05)
        await j.flush()
        events = await store.list_events("t1", "r1")
        types = {e["event_type"] for e in events}
        assert "run.start" in types
        assert "run.end" in types

    @pytest.mark.anyio
    async def test_nested_chain_no_run_start(self, journal_setup):
        """Nested chains (parent_run_id set) should NOT produce run.start."""
        j, store = journal_setup
        parent_id = uuid4()
        j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
        j.on_chain_end({}, run_id=uuid4())
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert not any(e["event_type"] == "run.start" for e in events)


class TestToolCallbacks:
    @pytest.mark.anyio
    async def test_tool_end_with_tool_message(self, journal_setup):
        """on_tool_end with a ToolMessage stores it as llm.tool.result."""
        from langchain_core.messages import ToolMessage

        j, store = journal_setup
        tool_msg = ToolMessage(content="results", tool_call_id="call_1", name="web_search")
        j.on_tool_end(tool_msg, run_id=uuid4())
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.tool.result"
        assert messages[0]["content"]["type"] == "tool"

    @pytest.mark.anyio
    async def test_tool_end_with_command_unwraps_tool_message(self, journal_setup):
        """on_tool_end with Command(update={'messages':[ToolMessage]}) unwraps inner message."""
        from langchain_core.messages import ToolMessage
        from langgraph.types import Command

        j, store = journal_setup
        inner = ToolMessage(content="file list", tool_call_id="call_2", name="present_files")
        cmd = Command(update={"messages": [inner]})
        j.on_tool_end(cmd, run_id=uuid4())
        await j.flush()
        messages = await store.list_messages("t1")
        assert len(messages) == 1
        assert messages[0]["event_type"] == "llm.tool.result"
        assert messages[0]["content"]["content"] == "file list"

    @pytest.mark.anyio
    async def test_on_tool_error_no_crash(self, journal_setup):
        """on_tool_error should not crash (no event emitted by default)."""
        j, store = journal_setup
        j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
        await j.flush()
        # Base implementation does not emit tool_error — just verify no crash
        events = await store.list_events("t1", "r1")
        assert isinstance(events, list)


class TestCustomEvents:
    @pytest.mark.anyio
    async def test_on_custom_event_not_implemented(self, journal_setup):
        """RunJournal does not implement on_custom_event — no crash expected."""
        j, store = journal_setup
        # BaseCallbackHandler.on_custom_event is a no-op by default
        j.on_custom_event("task_running", {"task_id": "t1"}, run_id=uuid4())
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert isinstance(events, list)


class TestBufferFlush:
    @pytest.mark.anyio
    async def test_flush_threshold(self, journal_setup):
        j, store = journal_setup
        j._flush_threshold = 2
        # Each on_llm_end emits 1 event
        j.on_llm_end(_make_llm_response("A"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        assert len(j._buffer) == 1
        j.on_llm_end(_make_llm_response("B"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        # At threshold the buffer should have been flushed asynchronously
        await asyncio.sleep(0.1)
        events = await store.list_events("t1", "r1")
        assert len(events) >= 2

    @pytest.mark.anyio
    async def test_events_retained_when_no_loop(self, journal_setup):
        """Events buffered in a sync (no-loop) context should survive
        until the async flush() in the finally block."""
        j, store = journal_setup
        j._flush_threshold = 1

        original = asyncio.get_running_loop

        def no_loop():
            raise RuntimeError("no running event loop")

        asyncio.get_running_loop = no_loop
        try:
            j._put(event_type="llm.ai.response", category="message", content="test")
        finally:
            asyncio.get_running_loop = original

        assert len(j._buffer) == 1
        await j.flush()
        events = await store.list_events("t1", "r1")
        assert any(e["event_type"] == "llm.ai.response" for e in events)


class TestIdentifyCaller:
    def test_lead_agent_tag(self, journal_setup):
        j, _ = journal_setup
        assert j._identify_caller(["lead_agent"]) == "lead_agent"

    def test_subagent_tag(self, journal_setup):
        j, _ = journal_setup
        assert j._identify_caller(["subagent:research"]) == "subagent:research"

    def test_middleware_tag(self, journal_setup):
        j, _ = journal_setup
        assert j._identify_caller(["middleware:summarization"]) == "middleware:summarization"

    def test_no_tags_returns_lead_agent(self, journal_setup):
        j, _ = journal_setup
        assert j._identify_caller([]) == "lead_agent"
        assert j._identify_caller(None) == "lead_agent"


class TestChainErrorCallback:
    @pytest.mark.anyio
    async def test_on_chain_error_writes_run_error(self, journal_setup):
        j, store = journal_setup
        j.on_chain_error(ValueError("boom"), run_id=uuid4())
        await asyncio.sleep(0.05)
        await j.flush()
        events = await store.list_events("t1", "r1")
        error_events = [e for e in events if e["event_type"] == "run.error"]
        assert len(error_events) == 1
        assert "boom" in error_events[0]["content"]
        assert error_events[0]["metadata"]["error_type"] == "ValueError"


class TestTokenTrackingDisabled:
    @pytest.mark.anyio
    async def test_track_token_usage_false(self):
        store = MemoryRunEventStore()
        j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
        j.on_llm_end(
            _make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}),
            run_id=uuid4(),
            parent_run_id=None,
            tags=["lead_agent"],
        )
        data = j.get_completion_data()
        assert data["total_tokens"] == 0
        assert data["llm_call_count"] == 0


class TestConvenienceFields:
    @pytest.mark.anyio
    async def test_first_human_message_via_set(self, journal_setup):
        j, _ = journal_setup
        j.set_first_human_message("What is AI?")
        data = j.get_completion_data()
        assert data["first_human_message"] == "What is AI?"

    @pytest.mark.anyio
    async def test_get_completion_data(self, journal_setup):
        j, _ = journal_setup
        j._total_tokens = 100
        j._msg_count = 5
        data = j.get_completion_data()
        assert data["total_tokens"] == 100
        assert data["message_count"] == 5


class TestMiddlewareEvents:
    @pytest.mark.anyio
    async def test_record_middleware_uses_middleware_category(self, journal_setup):
        j, store = journal_setup
        j.record_middleware(
            "title",
            name="TitleMiddleware",
            hook="after_model",
            action="generate_title",
            changes={"title": "Test Title", "thread_id": "t1"},
        )
        await j.flush()
        events = await store.list_events("t1", "r1")
        mw_events = [e for e in events if e["event_type"] == "middleware:title"]
        assert len(mw_events) == 1
        assert mw_events[0]["category"] == "middleware"
        assert mw_events[0]["content"]["name"] == "TitleMiddleware"
        assert mw_events[0]["content"]["hook"] == "after_model"
        assert mw_events[0]["content"]["action"] == "generate_title"
        assert mw_events[0]["content"]["changes"]["title"] == "Test Title"

    @pytest.mark.anyio
    async def test_middleware_tag_variants(self, journal_setup):
        """Different middleware tags produce distinct event_types."""
        j, store = journal_setup
        j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
        j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
        await j.flush()
        events = await store.list_events("t1", "r1")
        event_types = {e["event_type"] for e in events}
        assert "middleware:title" in event_types
        assert "middleware:guardrail" in event_types


class TestCallerBucketing:
    """Tests for caller-bucketed token accumulation (lead_agent / subagent / middleware)."""

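    # Bucketing exercised below: "subagent:*" tags accumulate into _subagent_tokens,
    # "middleware:*" tags into _middleware_tokens, and anything else (including
    # missing or unrecognized tags) into _lead_agent_tokens; each langchain run_id
    # is counted at most once.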
    def test_lead_agent_bucketing(self, journal_setup):
        j, _ = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        assert j._lead_agent_tokens == 15
        assert j._subagent_tokens == 0
        assert j._middleware_tokens == 0

    def test_subagent_bucketing(self, journal_setup):
        j, _ = journal_setup
        usage = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
        j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
        assert j._subagent_tokens == 30
        assert j._lead_agent_tokens == 0
        assert j._middleware_tokens == 0

    def test_middleware_bucketing(self, journal_setup):
        j, _ = journal_setup
        usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
        j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarize"])
        assert j._middleware_tokens == 7
        assert j._lead_agent_tokens == 0
        assert j._subagent_tokens == 0

    def test_mixed_callers_sum_independently(self, journal_setup):
        j, _ = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:bash"])
        j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:title"])
        assert j._lead_agent_tokens == 15
        assert j._subagent_tokens == 15
        assert j._middleware_tokens == 15
        assert j._total_tokens == 45

    def test_get_completion_data_includes_buckets(self, journal_setup):
        j, _ = journal_setup
        j._lead_agent_tokens = 100
        j._subagent_tokens = 200
        j._middleware_tokens = 50
        data = j.get_completion_data()
        assert data["lead_agent_tokens"] == 100
        assert data["subagent_tokens"] == 200
        assert data["middleware_tokens"] == 50

    def test_dedup_same_run_id(self, journal_setup):
        """Same langchain run_id in on_llm_end must not double-count."""
        j, _ = journal_setup
        run_id = uuid4()
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        assert j._total_tokens == 15
        assert j._lead_agent_tokens == 15
        assert j._llm_call_count == 1

    def test_first_no_usage_second_with_usage(self, journal_setup):
        """First callback with no usage must not block second callback with usage for same run_id."""
        j, _ = journal_setup
        run_id = uuid4()
        j.on_llm_end(_make_llm_response("A", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        assert str(run_id) not in j._counted_llm_run_ids
        # Second callback for the same run_id with actual usage must still count
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
        assert j._total_tokens == 15
        assert j._lead_agent_tokens == 15

    def test_track_token_usage_false_skips_buckets(self):
        """When token tracking is disabled, caller buckets stay at 0."""
        store = MemoryRunEventStore()
        j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("X", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
        assert j._subagent_tokens == 0
        assert j._lead_agent_tokens == 0

    def test_default_no_tags_buckets_as_lead_agent(self, journal_setup):
        """LLM calls without explicit tags default to lead_agent bucket."""
        j, _ = journal_setup
        usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
        j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None)
        assert j._lead_agent_tokens == 10
        assert j._subagent_tokens == 0
        assert j._middleware_tokens == 0

    def test_unknown_tag_buckets_as_lead_agent(self, journal_setup):
        """Calls with unrecognized tags (not lead_agent/subagent:/middleware:) go to lead_agent."""
        j, _ = journal_setup
        usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
        j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["some_random_tag"])
        assert j._lead_agent_tokens == 10


class TestExternalUsageRecords:
    """Tests for record_external_llm_usage_records."""

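    # Each record is a dict with source_run_id, caller, input_tokens, output_tokens
    # and total_tokens; source_run_id deduplicates repeated reports and empty ids
    # are skipped, as the cases below exercise.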
    def test_records_added_to_subagent_bucket(self, journal_setup):
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-1",
                "caller": "subagent:general-purpose",
                "input_tokens": 100,
                "output_tokens": 50,
                "total_tokens": 150,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 150
        assert j._total_tokens == 150
        assert j._total_input_tokens == 100
        assert j._total_output_tokens == 50

    def test_records_added_to_middleware_bucket(self, journal_setup):
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-2",
                "caller": "middleware:summarize",
                "input_tokens": 30,
                "output_tokens": 10,
                "total_tokens": 40,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._middleware_tokens == 40
        assert j._lead_agent_tokens == 0
        assert j._subagent_tokens == 0

    def test_records_added_to_lead_agent_bucket(self, journal_setup):
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-3",
                "caller": "lead_agent",
                "input_tokens": 10,
                "output_tokens": 5,
                "total_tokens": 15,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._lead_agent_tokens == 15

    def test_dedup_same_source_run_id(self, journal_setup):
        """Same source_run_id must not be double-counted."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "dup-1",
                "caller": "subagent:research",
                "input_tokens": 50,
                "output_tokens": 25,
                "total_tokens": 75,
            }
        ]
        j.record_external_llm_usage_records(records)
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 75
        assert j._total_tokens == 75

    def test_total_tokens_missing_computed_from_input_output(self, journal_setup):
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-4",
                "caller": "subagent:bash",
                "input_tokens": 200,
                "output_tokens": 100,
                "total_tokens": 0,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 300
        assert j._total_tokens == 300

    def test_total_tokens_zero_no_count(self, journal_setup):
        """Records with zero total and zero input+output must not be counted."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-5",
                "caller": "subagent:research",
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._total_tokens == 0
        assert j._subagent_tokens == 0

    def test_empty_source_run_id_skipped(self, journal_setup):
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "",
                "caller": "subagent:research",
                "input_tokens": 50,
                "output_tokens": 25,
                "total_tokens": 75,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._total_tokens == 0

    def test_multiple_records_in_single_call(self, journal_setup):
        j, _ = journal_setup
        records = [
            {"source_run_id": "r1", "caller": "subagent:gp", "input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
            {"source_run_id": "r2", "caller": "subagent:bash", "input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
        ]
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 45
        assert j._total_tokens == 45

    def test_external_records_coexist_with_inline_callbacks(self, journal_setup):
        """External records and inline on_llm_end must not interfere."""
        j, _ = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        j.record_external_llm_usage_records([{"source_run_id": "ext-6", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
        assert j._lead_agent_tokens == 15
        assert j._subagent_tokens == 150
        assert j._total_tokens == 165

    def test_track_token_usage_false_skips_external_records(self):
        """When token tracking is disabled, external records must not accumulate."""
        store = MemoryRunEventStore()
        j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
        j.record_external_llm_usage_records([{"source_run_id": "ext-7", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
        assert j._total_tokens == 0
        assert j._subagent_tokens == 0


class TestChatModelStartHumanMessage:
    """Tests for on_chat_model_start extracting the first human message."""

    @pytest.mark.anyio
    async def test_extracts_first_human_message(self, journal_setup):
        """on_chat_model_start captures the first HumanMessage from prompts."""
        from langchain_core.messages import AIMessage, HumanMessage

        j, store = journal_setup
        messages_batch = [
            [HumanMessage(content="What is AI?"), AIMessage(content="Hi there")],
        ]
        j.on_chat_model_start({}, messages_batch, run_id=uuid4(), tags=["lead_agent"])
        await j.flush()

        assert j._first_human_msg == "What is AI?"
        events = await store.list_events("t1", "r1")
        human_events = [e for e in events if e["event_type"] == "llm.human.input"]
        assert len(human_events) == 1
        assert human_events[0]["content"]["content"] == "What is AI?"

    @pytest.mark.anyio
    async def test_skips_summary_named_human_messages(self, journal_setup):
        """HumanMessages with name='summary' are skipped."""
        from langchain_core.messages import HumanMessage

        j, store = journal_setup
        messages_batch = [
            [HumanMessage(content="Summarized context", name="summary"), HumanMessage(content="Real question")],
        ]
        j.on_chat_model_start({}, messages_batch, run_id=uuid4(), tags=["lead_agent"])
        await j.flush()

        assert j._first_human_msg == "Real question"

    @pytest.mark.anyio
    async def test_only_first_human_message_captured(self, journal_setup):
        """Subsequent on_chat_model_start calls do not overwrite the first message."""
        from langchain_core.messages import HumanMessage

        j, store = journal_setup
        j.on_chat_model_start({}, [[HumanMessage(content="First question")]], run_id=uuid4(), tags=["lead_agent"])
        j.on_chat_model_start({}, [[HumanMessage(content="Second question")]], run_id=uuid4(), tags=["lead_agent"])
        await j.flush()

        assert j._first_human_msg == "First question"
        events = await store.list_events("t1", "r1")
        human_events = [e for e in events if e["event_type"] == "llm.human.input"]
        assert len(human_events) == 1

    @pytest.mark.anyio
    async def test_empty_messages_no_crash(self, journal_setup):
        """on_chat_model_start with empty messages does not crash."""
        j, store = journal_setup
        j.on_chat_model_start({}, [], run_id=uuid4(), tags=["lead_agent"])
        await j.flush()
        assert j._first_human_msg is None