Mirror of https://github.com/bytedance/deer-flow.git, synced 2026-04-25 11:18:22 +00:00
feat(events): align message events with checkpoint format and add middleware tag injection
- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
parent 2d135aad0f
commit 52e7acafee
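The core of the change is the shape of message-event content: events previously stored hand-built OpenAI-style dicts, and now store BaseMessage.model_dump() output, the same serialization LangGraph writes into checkpoint values.messages. A rough sketch of the two shapes for an assistant reply (the key set follows langchain-core's AIMessage and may vary by version):

    # Before: hand-built OpenAI-style dict
    {"role": "assistant", "content": "Answer"}

    # After: AIMessage.model_dump() output (illustrative; exact keys depend
    # on the installed langchain-core version)
    {
        "type": "ai",
        "content": "Answer",
        "additional_kwargs": {},
        "response_metadata": {"model_name": "test-model"},
        "name": None,
        "id": "msg-...",
        "tool_calls": [],
        "invalid_tool_calls": [],
        "usage_metadata": None,
    }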
@@ -56,13 +56,15 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
     # Prepare keep parameter
     keep = config.keep.to_tuple()

-    # Prepare model parameter
+    # Prepare model parameter.
+    # Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
+    # as middleware rather than lead_agent (SummarizationMiddleware is a
+    # LangChain built-in, so we tag the model at creation time).
     if config.model_name:
         model = create_chat_model(name=config.model_name, thinking_enabled=False)
     else:
         # Use a lightweight model for summarization to save costs
         # Falls back to default model if not explicitly specified
         model = create_chat_model(thinking_enabled=False)
+    model = model.with_config(tags=["middleware:summarize"])

     # Prepare kwargs
     kwargs = {
@@ -1,10 +1,11 @@
 """Middleware for automatic thread title generation."""

 import logging
-from typing import NotRequired, override
+from typing import Any, NotRequired, override

 from langchain.agents import AgentState
 from langchain.agents.middleware import AgentMiddleware
+from langgraph.config import get_config
 from langgraph.runtime import Runtime

 from deerflow.config.title_config import get_title_config
@@ -100,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
             return user_msg[:fallback_chars].rstrip() + "..."
         return user_msg if user_msg else "New Conversation"

+    def _get_runnable_config(self) -> dict[str, Any]:
+        """Inherit the parent RunnableConfig and add middleware tag.
+
+        This ensures RunJournal identifies LLM calls from this middleware
+        as ``middleware:title`` instead of ``lead_agent``.
+        """
+        try:
+            parent = get_config()
+        except Exception:
+            parent = {}
+        config = {**parent}
+        config["tags"] = [*(config.get("tags") or []), "middleware:title"]
+        return config
+
     def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
         """Synchronously generate a title. Returns state update or None."""
         if not self._should_generate_title(state):
@@ -110,7 +125,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
         model = create_chat_model(name=config.model_name, thinking_enabled=False)

         try:
-            response = model.invoke(prompt)
+            response = model.invoke(prompt, config=self._get_runnable_config())
             title = self._parse_title(response.content)
             if not title:
                 title = self._fallback_title(user_msg)
@@ -130,7 +145,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
         model = create_chat_model(name=config.model_name, thinking_enabled=False)

         try:
-            response = await model.ainvoke(prompt)
+            response = await model.ainvoke(prompt, config=self._get_runnable_config())
             title = self._parse_title(response.content)
             if not title:
                 title = self._fallback_title(user_msg)
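A sketch of the tag-inheritance effect above: inside the middleware hook, langgraph's get_config() returns the enclosing run's RunnableConfig, so the title-model call is attributed to the middleware rather than the agent. Assuming, for illustration, that the parent run carried tags=["lead_agent"]:

    parent = get_config()                 # {"tags": ["lead_agent"], ...}
    config = self._get_runnable_config()  # {"tags": ["lead_agent", "middleware:title"], ...}
    response = model.invoke(prompt, config=config)  # RunJournal sees "middleware:title"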
@@ -179,7 +179,8 @@ class RunJournal(BaseCallbackHandler):
             },
         )

-        # Message events: only lead_agent gets message-category events
+        # Message events: only lead_agent gets message-category events.
+        # Content uses message.model_dump() to align with checkpoint format.
        tool_calls = getattr(message, "tool_calls", None) or []
        if caller == "lead_agent":
            resp_meta = getattr(message, "response_metadata", None) or {}
@@ -189,7 +190,7 @@ class RunJournal(BaseCallbackHandler):
                 self._put(
                     event_type="ai_tool_call",
                     category="message",
-                    content=langchain_to_openai_message(message),
+                    content=message.model_dump(),
                     metadata={"model_name": model_name, "finish_reason": "tool_calls"},
                 )
             elif isinstance(content, str) and content:
@@ -197,10 +198,10 @@ class RunJournal(BaseCallbackHandler):
                 self._put(
                     event_type="ai_message",
                     category="message",
-                    content={"role": "assistant", "content": content},
+                    content=message.model_dump(),
                     metadata={"model_name": model_name, "finish_reason": "stop"},
                 )
-                self._last_ai_msg = content[:2000]
+                self._last_ai_msg = content
                 self._msg_count += 1

         # Token accumulation
@@ -242,45 +243,87 @@ class RunJournal(BaseCallbackHandler):
             },
         )

-    def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
-        tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
-        tool_name = kwargs.get("name", "")
+    def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
+        from langchain_core.messages import ToolMessage
+
+        # Extract fields from ToolMessage object when LangChain provides one.
+        # LangChain's _format_output wraps tool results into a ToolMessage
+        # with tool_call_id, name, status, and artifact — more complete than
+        # what kwargs alone provides.
+        if isinstance(output, ToolMessage):
+            tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
+            tool_name = output.name or kwargs.get("name", "")
+            status = getattr(output, "status", "success") or "success"
+            content_str = output.content if isinstance(output.content, str) else str(output.content)
+            # Use model_dump() for checkpoint-aligned message content.
+            # Override tool_call_id if it was resolved from cache.
+            msg_content = output.model_dump()
+            if msg_content.get("tool_call_id") != tool_call_id:
+                msg_content["tool_call_id"] = tool_call_id
+        else:
+            tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
+            tool_name = kwargs.get("name", "")
+            status = "success"
+            content_str = str(output)
+            # Construct checkpoint-aligned dict when output is a plain string.
+            msg_content = ToolMessage(
+                content=content_str,
+                tool_call_id=tool_call_id or "",
+                name=tool_name,
+                status=status,
+            ).model_dump()

         # Trace event (always)
         self._put(
             event_type="tool_end",
             category="trace",
-            content=str(output),
+            content=content_str,
             metadata={
                 "tool_name": tool_name,
                 "tool_call_id": tool_call_id,
-                "status": "success",
+                "status": status,
             },
         )

-        # Message event: tool_result
+        # Message event: tool_result (checkpoint-aligned model_dump format)
         self._put(
             event_type="tool_result",
             category="message",
-            content={
-                "role": "tool",
-                "tool_call_id": tool_call_id or "",
-                "content": str(output),
-            },
-            metadata={"tool_name": tool_name},
+            content=msg_content,
+            metadata={"tool_name": tool_name, "status": status},
         )

     def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
+        from langchain_core.messages import ToolMessage
+
+        tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
+        tool_name = kwargs.get("name", "")
+
         # Trace event
         self._put(
             event_type="tool_error",
             category="trace",
             content=str(error),
             metadata={
-                "tool_name": kwargs.get("name", ""),
-                "tool_call_id": kwargs.get("tool_call_id"),
+                "tool_name": tool_name,
+                "tool_call_id": tool_call_id,
             },
         )

+        # Message event: tool_result with error status (checkpoint-aligned)
+        msg_content = ToolMessage(
+            content=str(error),
+            tool_call_id=tool_call_id or "",
+            name=tool_name,
+            status="error",
+        ).model_dump()
+        self._put(
+            event_type="tool_result",
+            category="message",
+            content=msg_content,
+            metadata={"tool_name": tool_name, "status": "error"},
+        )
+
     # -- Custom event callback --

     def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
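With this change a failing tool still yields a tool_result message event. For j.on_tool_error(TimeoutError("timeout"), ..., name="web_fetch", tool_call_id="call_1"), the stored content is roughly this ToolMessage.model_dump() dict (remaining model_dump keys omitted for brevity):

    {
        "type": "tool",
        "content": "timeout",  # str(error)
        "tool_call_id": "call_1",
        "name": "web_fetch",
        "status": "error",
    }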
@@ -298,8 +341,8 @@ class RunJournal(BaseCallbackHandler):
             },
         )
         self._put(
-            event_type="summary",
-            category="message",
+            event_type="middleware:summarize",
+            category="middleware",
             content={"role": "system", "content": data_dict.get("summary", "")},
             metadata={"replaced_count": data_dict.get("replaced_count", 0)},
         )
@@ -366,16 +409,24 @@ class RunJournal(BaseCallbackHandler):
         """Record the first human message for convenience fields."""
         self._first_human_msg = content[:2000] if content else None

-    def record_middleware(self, name: str, hook: str, action: str, changes: dict) -> None:
-        """Record a middleware trace event.
+    def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
+        """Record a middleware state-change event.

         Called by middleware implementations when they perform a meaningful
         state change (e.g., title generation, summarization, HITL approval).
         Pure-observation middleware should not call this.

         Args:
+            tag: Short identifier for the middleware (e.g., "title", "summarize",
+                "guardrail"). Used to form event_type="middleware:{tag}".
             name: Full middleware class name.
             hook: Lifecycle hook that triggered the action (e.g., "after_model").
             action: Specific action performed (e.g., "generate_title").
             changes: Dict describing the state changes made.
         """
         self._put(
-            event_type="middleware",
-            category="trace",
+            event_type=f"middleware:{tag}",
+            category="middleware",
             content={"name": name, "hook": hook, "action": action, "changes": changes},
         )
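A call under the new record_middleware signature, mirroring the updated tests below:

    journal.record_middleware(
        "title",
        name="TitleMiddleware",
        hook="after_model",
        action="generate_title",
        changes={"title": "Quantum Computing"},
    )
    # stored with event_type="middleware:title", category="middleware"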
@@ -67,9 +67,9 @@ async def run_agent(
         track_token_usage=getattr(run_events_config, "track_token_usage", True),
     )

-    # Write human_message event
-    user_input = _extract_user_input(graph_input)
-    if user_input:
+    # Write human_message event (model_dump format, aligned with checkpoint)
+    human_msg = _extract_human_message(graph_input)
+    if human_msg is not None:
         msg_metadata = {}
         if follow_up_to_run_id:
             msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
@@ -78,10 +78,11 @@ async def run_agent(
             run_id=run_id,
             event_type="human_message",
             category="message",
-            content={"role": "user", "content": user_input},
+            content=human_msg.model_dump(),
             metadata=msg_metadata or None,
         )
-        journal.set_first_human_message(user_input)
+        content = human_msg.content
+        journal.set_first_human_message(content if isinstance(content, str) else str(content))

     # Track whether "events" was requested but skipped
     if "events" in requested_modes:
@@ -282,21 +283,29 @@ def _lg_mode_to_sse_event(mode: str) -> str:
     return mode


-def _extract_user_input(graph_input: dict) -> str:
-    """Extract user input text from graph_input for event recording."""
+def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
+    """Extract or construct a HumanMessage from graph_input for event recording.
+
+    Returns a LangChain HumanMessage so callers can use .model_dump() to get
+    the checkpoint-aligned serialization format.
+    """
+    from langchain_core.messages import HumanMessage
+
     messages = graph_input.get("messages")
     if not messages:
-        return ""
-    # Take the last message (usually the user's input)
+        return None
     last = messages[-1] if isinstance(messages, list) else messages
-    if isinstance(last, str):
+    if isinstance(last, HumanMessage):
         return last
+    if isinstance(last, str):
+        return HumanMessage(content=last) if last else None
     if hasattr(last, "content"):
         content = last.content
-        return content if isinstance(content, str) else str(content)
+        return HumanMessage(content=content)
     if isinstance(last, dict):
-        return str(last.get("content", ""))
-    return ""
+        content = last.get("content", "")
+        return HumanMessage(content=content) if content else None
+    return None


 def _unpack_stream_item(
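Behavior sketch for the new _extract_human_message helper (illustrative inputs):

    _extract_human_message({"messages": [HumanMessage(content="hi")]})  # the HumanMessage itself
    _extract_human_message({"messages": ["hi"]})                        # HumanMessage(content="hi")
    _extract_human_message({"messages": [{"content": ""}]})             # None (empty content)
    _extract_human_message({})                                          # None (no messages)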
@@ -146,8 +146,11 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
         lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
     )

+    from unittest.mock import MagicMock
+
     captured: dict[str, object] = {}
-    fake_model = object()
+    fake_model = MagicMock()
+    fake_model.with_config.return_value = fake_model

     def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
         captured["name"] = name
@@ -163,3 +166,4 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
     assert captured["name"] == "model-masswork"
     assert captured["thinking_enabled"] is False
     assert middleware["model"] is fake_model
+    fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
@@ -20,24 +20,32 @@ def journal_setup():
     return j, store


-def _make_llm_response(content="Hello", usage=None, tool_calls=None):
-    """Create a mock LLM response with a message."""
+def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
+    """Create a mock LLM response with a message.
+
+    model_dump() returns checkpoint-aligned format matching real AIMessage.
+    """
     msg = MagicMock()
     msg.type = "ai"
     msg.content = content
     msg.id = f"msg-{id(msg)}"
     msg.tool_calls = tool_calls or []
     msg.invalid_tool_calls = []
     msg.response_metadata = {"model_name": "test-model"}
     msg.usage_metadata = usage
-    # Provide a real model_dump so serialize_lc_object returns a plain dict
-    # (needed for DB-backed tests where json.dumps must succeed).
+    msg.additional_kwargs = additional_kwargs or {}
+    msg.name = None
+    # model_dump returns checkpoint-aligned format
     msg.model_dump.return_value = {
-        "type": "ai",
         "content": content,
+        "additional_kwargs": additional_kwargs or {},
+        "response_metadata": {"model_name": "test-model"},
+        "type": "ai",
+        "name": None,
+        "id": msg.id,
         "tool_calls": tool_calls or [],
         "invalid_tool_calls": [],
         "usage_metadata": usage,
-        "response_metadata": {"model_name": "test-model"},
     }

     gen = MagicMock()
@@ -71,7 +79,9 @@ class TestLlmCallbacks:
         messages = await store.list_messages("t1")
         assert len(messages) == 1
         assert messages[0]["event_type"] == "ai_message"
-        assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
+        # Content is checkpoint-aligned model_dump format
+        assert messages[0]["content"]["type"] == "ai"
+        assert messages[0]["content"]["content"] == "Answer"

     @pytest.mark.anyio
     async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
@@ -211,10 +221,14 @@ class TestCustomEvents:
         events = await store.list_events("t1", "r1")
         trace = [e for e in events if e["event_type"] == "summarization"]
         assert len(trace) == 1
+        # Summarization goes to middleware category, not message
+        mw_events = [e for e in events if e["event_type"] == "middleware:summarize"]
+        assert len(mw_events) == 1
+        assert mw_events[0]["category"] == "middleware"
+        assert mw_events[0]["content"] == {"role": "system", "content": "Context was summarized."}
+        # No message events from summarization
         messages = await store.list_messages("t1")
-        assert len(messages) == 1
-        assert messages[0]["event_type"] == "summary"
-        assert messages[0]["content"] == {"role": "system", "content": "Context was summarized."}
+        assert len(messages) == 0

     @pytest.mark.anyio
     async def test_non_summarization_custom_event(self, journal_setup):
@@ -375,8 +389,11 @@ class TestDbBackedLifecycle:
         record = await mgr.create("t1", "lead_agent")
         run_id = record.run_id

-        # Write human_message
-        await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content={"role": "user", "content": "Hello DB"})
+        # Write human_message (checkpoint-aligned format)
+        from langchain_core.messages import HumanMessage
+
+        human_msg = HumanMessage(content="Hello DB")
+        await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content=human_msg.model_dump())

         # Simulate journal
         journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
@@ -406,12 +423,14 @@ class TestDbBackedLifecycle:
         assert row["status"] == "success"
         assert row["total_tokens"] == 15

-        # Verify messages from DB
+        # Verify messages from DB (checkpoint-aligned format)
         messages = await event_store.list_messages("t1")
         assert len(messages) == 2
         assert messages[0]["event_type"] == "human_message"
+        assert messages[0]["content"]["type"] == "human"
         assert messages[1]["event_type"] == "ai_message"
-        assert messages[1]["content"] == {"role": "assistant", "content": "DB response"}
+        assert messages[1]["content"]["type"] == "ai"
+        assert messages[1]["content"]["content"] == "DB response"

         # Verify events from DB
         events = await event_store.list_events("t1", run_id)
@@ -560,38 +579,45 @@ class TestDictContent:
         await close_engine()


-class TestOpenAIHumanMessage:
+class TestCheckpointAlignedHumanMessage:
     @pytest.mark.anyio
-    async def test_human_message_openai_format(self):
+    async def test_human_message_checkpoint_format(self):
+        """human_message content uses model_dump() checkpoint format."""
+        from langchain_core.messages import HumanMessage
+
         store = MemoryRunEventStore()
+        human_msg = HumanMessage(content="What is AI?")
         await store.put(
             thread_id="t1",
             run_id="r1",
             event_type="human_message",
             category="message",
-            content={"role": "user", "content": "What is AI?"},
+            content=human_msg.model_dump(),
             metadata={"message_id": "msg_001"},
         )
         messages = await store.list_messages("t1")
         assert len(messages) == 1
-        assert messages[0]["content"] == {"role": "user", "content": "What is AI?"}
-        assert messages[0]["content"]["role"] == "user"
+        assert messages[0]["content"]["type"] == "human"
+        assert messages[0]["content"]["content"] == "What is AI?"


-class TestOpenAIMessageFormat:
+class TestCheckpointAlignedMessageFormat:
     @pytest.mark.anyio
-    async def test_ai_message_openai_format(self, journal_setup):
-        """ai_message content should be OpenAI assistant message dict."""
+    async def test_ai_message_checkpoint_format(self, journal_setup):
+        """ai_message content should be checkpoint-aligned model_dump dict."""
         j, store = journal_setup
         j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"])
         await j.flush()
         messages = await store.list_messages("t1")
         assert len(messages) == 1
-        assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
+        assert messages[0]["content"]["type"] == "ai"
+        assert messages[0]["content"]["content"] == "Answer"
+        assert "response_metadata" in messages[0]["content"]
+        assert "additional_kwargs" in messages[0]["content"]

     @pytest.mark.anyio
     async def test_ai_tool_call_event(self, journal_setup):
-        """LLM response with tool_calls should produce ai_tool_call message event."""
+        """LLM response with tool_calls should produce ai_tool_call with model_dump content."""
         j, store = journal_setup
         tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}]
         j.on_llm_end(
@@ -603,13 +629,12 @@ class TestOpenAIMessageFormat:
         messages = await store.list_messages("t1")
         assert len(messages) == 1
         assert messages[0]["event_type"] == "ai_tool_call"
-        assert messages[0]["content"]["role"] == "assistant"
+        assert messages[0]["content"]["type"] == "ai"
         assert messages[0]["content"]["content"] == "Let me search"
         assert len(messages[0]["content"]["tool_calls"]) == 1
         tc = messages[0]["content"]["tool_calls"][0]
         assert tc["id"] == "call_1"
-        assert tc["type"] == "function"
-        assert tc["function"]["name"] == "search"
+        assert tc["name"] == "search"

     @pytest.mark.anyio
     async def test_ai_tool_call_only_from_lead_agent(self, journal_setup):
@@ -637,11 +662,11 @@ class TestToolResultMessage:
         messages = await store.list_messages("t1")
         assert len(messages) == 1
         assert messages[0]["event_type"] == "tool_result"
-        assert messages[0]["content"] == {
-            "role": "tool",
-            "tool_call_id": "call_abc",
-            "content": "search results here",
-        }
+        # Content is checkpoint-aligned model_dump format
+        assert messages[0]["content"]["type"] == "tool"
+        assert messages[0]["content"]["tool_call_id"] == "call_abc"
+        assert messages[0]["content"]["content"] == "search results here"
+        assert messages[0]["content"]["name"] == "web_search"

     @pytest.mark.anyio
     async def test_tool_result_missing_tool_call_id(self, journal_setup):
@@ -652,15 +677,128 @@ class TestToolResultMessage:
         await j.flush()
         messages = await store.list_messages("t1")
         assert len(messages) == 1
-        assert messages[0]["content"]["role"] == "tool"
+        assert messages[0]["content"]["type"] == "tool"

     @pytest.mark.anyio
-    async def test_tool_error_no_tool_result_message(self, journal_setup):
+    async def test_tool_end_extracts_from_tool_message_object(self, journal_setup):
+        """When LangChain passes a ToolMessage object as output, extract fields from it."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        tool_msg = ToolMessage(
+            content="search results",
+            tool_call_id="call_from_obj",
+            name="web_search",
+            status="success",
+        )
+        j.on_tool_end(tool_msg, run_id=run_id)
+        await j.flush()
+
+        messages = await store.list_messages("t1")
+        assert len(messages) == 1
+        assert messages[0]["content"]["type"] == "tool"
+        assert messages[0]["content"]["tool_call_id"] == "call_from_obj"
+        assert messages[0]["content"]["content"] == "search results"
+        assert messages[0]["content"]["name"] == "web_search"
+        assert messages[0]["metadata"]["tool_name"] == "web_search"
+        assert messages[0]["metadata"]["status"] == "success"
+
+        events = await store.list_events("t1", "r1")
+        tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
+        assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
+        assert tool_end["metadata"]["tool_name"] == "web_search"
+
+    @pytest.mark.anyio
+    async def test_tool_message_object_overrides_kwargs(self, journal_setup):
+        """ToolMessage object fields take priority over kwargs."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        tool_msg = ToolMessage(
+            content="result",
+            tool_call_id="call_obj",
+            name="tool_a",
+            status="success",
+        )
+        # Pass different values in kwargs — ToolMessage should win
+        j.on_tool_end(tool_msg, run_id=run_id, name="tool_b", tool_call_id="call_kwarg")
+        await j.flush()
+
+        messages = await store.list_messages("t1")
+        assert messages[0]["content"]["tool_call_id"] == "call_obj"
+        assert messages[0]["content"]["name"] == "tool_a"
+        assert messages[0]["metadata"]["tool_name"] == "tool_a"
+
+    @pytest.mark.anyio
+    async def test_tool_message_error_status(self, journal_setup):
+        """ToolMessage with status='error' propagates status to metadata."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        tool_msg = ToolMessage(
+            content="something went wrong",
+            tool_call_id="call_err",
+            name="web_fetch",
+            status="error",
+        )
+        j.on_tool_end(tool_msg, run_id=run_id)
+        await j.flush()
+
+        events = await store.list_events("t1", "r1")
+        tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
+        assert tool_end["metadata"]["status"] == "error"
+
+        messages = await store.list_messages("t1")
+        assert messages[0]["content"]["status"] == "error"
+        assert messages[0]["metadata"]["status"] == "error"
+
+    @pytest.mark.anyio
+    async def test_tool_message_fallback_to_cache(self, journal_setup):
+        """If ToolMessage has empty tool_call_id, fall back to cache from on_tool_start."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        j.on_tool_start({"name": "bash"}, "ls", run_id=run_id, tool_call_id="call_cached")
+        tool_msg = ToolMessage(
+            content="file list",
+            tool_call_id="",
+            name="bash",
+        )
+        j.on_tool_end(tool_msg, run_id=run_id)
+        await j.flush()
+
+        messages = await store.list_messages("t1")
+        assert messages[0]["content"]["tool_call_id"] == "call_cached"
+
+    @pytest.mark.anyio
+    async def test_tool_error_produces_tool_result_message(self, journal_setup):
         j, store = journal_setup
         j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1")
         await j.flush()
         messages = await store.list_messages("t1")
-        assert len(messages) == 0
+        assert len(messages) == 1
+        assert messages[0]["event_type"] == "tool_result"
+        assert messages[0]["content"]["type"] == "tool"
+        assert messages[0]["content"]["tool_call_id"] == "call_1"
+        assert "timeout" in messages[0]["content"]["content"]
+        assert messages[0]["content"]["status"] == "error"
+        assert messages[0]["metadata"]["status"] == "error"
+
+    @pytest.mark.anyio
+    async def test_tool_error_uses_cached_tool_call_id(self, journal_setup):
+        """on_tool_error should fall back to cached tool_call_id from on_tool_start."""
+        j, store = journal_setup
+        run_id = uuid4()
+        j.on_tool_start({"name": "web_fetch"}, "url", run_id=run_id, tool_call_id="call_cached")
+        j.on_tool_error(TimeoutError("timeout"), run_id=run_id, name="web_fetch")
+        await j.flush()
+        messages = await store.list_messages("t1")
+        assert len(messages) == 1
+        assert messages[0]["content"]["tool_call_id"] == "call_cached"


 def _make_base_messages():
@@ -745,11 +883,12 @@ class TestLlmRequestResponse:
         assert not any(e["event_type"] == "llm_start" for e in events)


-class TestMiddlewareTrace:
+class TestMiddlewareEvents:
     @pytest.mark.anyio
-    async def test_record_middleware(self, journal_setup):
+    async def test_record_middleware_uses_middleware_category(self, journal_setup):
         j, store = journal_setup
         j.record_middleware(
+            "title",
             name="TitleMiddleware",
             hook="after_model",
             action="generate_title",
@@ -757,27 +896,60 @@ class TestMiddlewareTrace:
         )
         await j.flush()
         events = await store.list_events("t1", "r1")
-        mw_events = [e for e in events if e["event_type"] == "middleware"]
+        mw_events = [e for e in events if e["event_type"] == "middleware:title"]
         assert len(mw_events) == 1
-        assert mw_events[0]["category"] == "trace"
+        assert mw_events[0]["category"] == "middleware"
         assert mw_events[0]["content"]["name"] == "TitleMiddleware"
         assert mw_events[0]["content"]["hook"] == "after_model"
         assert mw_events[0]["content"]["action"] == "generate_title"
         assert mw_events[0]["content"]["changes"]["title"] == "Test Title"

+    @pytest.mark.anyio
+    async def test_middleware_events_not_in_messages(self, journal_setup):
+        """Middleware events should not appear in list_messages()."""
+        j, store = journal_setup
+        j.record_middleware(
+            "title",
+            name="TitleMiddleware",
+            hook="after_model",
+            action="generate_title",
+            changes={"title": "Test"},
+        )
+        await j.flush()
+        messages = await store.list_messages("t1")
+        assert len(messages) == 0
+
+    @pytest.mark.anyio
+    async def test_middleware_tag_variants(self, journal_setup):
+        """Different middleware tags produce distinct event_types."""
+        j, store = journal_setup
+        j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
+        j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
+        await j.flush()
+        events = await store.list_events("t1", "r1")
+        event_types = {e["event_type"] for e in events}
+        assert "middleware:title" in event_types
+        assert "middleware:guardrail" in event_types


 class TestFullRunSequence:
     @pytest.mark.anyio
     async def test_complete_run_event_sequence(self):
-        """Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply."""
+        """Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply.
+
+        All message events use checkpoint-aligned model_dump format.
+        """
+        from langchain_core.messages import HumanMessage
+
         store = MemoryRunEventStore()
         j = RunJournal("r1", "t1", store, flush_threshold=100)

-        # 1. Human message (written by worker, not journal)
+        # 1. Human message (written by worker, using model_dump format)
+        human_msg = HumanMessage(content="Search for quantum computing")
         await store.put(
             thread_id="t1", run_id="r1",
             event_type="human_message", category="message",
-            content={"role": "user", "content": "Search for quantum computing"},
+            content=human_msg.model_dump(),
         )
         j.set_first_human_message("Search for quantum computing")
@@ -805,7 +977,7 @@ class TestFullRunSequence:
         j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1")

         # 5. Middleware: title generation
-        j.record_middleware("TitleMiddleware", "after_model", "generate_title", {"title": "Quantum Computing"})
+        j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={"title": "Quantum Computing"})

         # 6. Second LLM call -> final reply
         llm2_id = uuid4()
@@ -824,18 +996,19 @@ class TestFullRunSequence:
         await asyncio.sleep(0.05)
         await j.flush()

-        # Verify message sequence (what gets exported for training)
+        # Verify message sequence
         messages = await store.list_messages("t1")
         msg_types = [m["event_type"] for m in messages]
         assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"]

-        # Verify message content format
-        assert messages[0]["content"]["role"] == "user"
-        assert messages[1]["content"]["role"] == "assistant"
+        # Verify checkpoint-aligned format: all messages use "type" not "role"
+        assert messages[0]["content"]["type"] == "human"
+        assert messages[0]["content"]["content"] == "Search for quantum computing"
+        assert messages[1]["content"]["type"] == "ai"
         assert "tool_calls" in messages[1]["content"]
-        assert messages[2]["content"]["role"] == "tool"
+        assert messages[2]["content"]["type"] == "tool"
         assert messages[2]["content"]["tool_call_id"] == "call_1"
-        assert messages[3]["content"]["role"] == "assistant"
+        assert messages[3]["content"]["type"] == "ai"
         assert messages[3]["content"]["content"] == "Here are the results about quantum computing..."

         # Verify trace events
@@ -845,10 +1018,14 @@ class TestFullRunSequence:
         assert "llm_response" in trace_types
         assert "tool_start" in trace_types
         assert "tool_end" in trace_types
-        assert "middleware" in trace_types
         assert "llm_start" not in trace_types
         assert "llm_end" not in trace_types

+        # Verify middleware events are in their own category
+        mw_events = [e for e in events if e["category"] == "middleware"]
+        assert len(mw_events) == 1
+        assert mw_events[0]["event_type"] == "middleware:title"
+
         # Verify token accumulation
         data = j.get_completion_data()
         assert data["total_tokens"] == 420  # 120 + 300
@@ -857,7 +1034,7 @@ class TestFullRunSequence:
         assert data["message_count"] == 1  # only final ai_message counts
         assert data["last_ai_message"] == "Here are the results about quantum computing..."

-        # Verify training data export is trivial
-        training_messages = [m["content"] for m in messages]
-        assert all(isinstance(m, dict) for m in training_messages)
-        assert all("role" in m for m in training_messages)
+        # Verify all message contents are checkpoint-aligned dicts with "type" field
+        for m in messages:
+            assert isinstance(m["content"], dict)
+            assert "type" in m["content"]