feat(events): align message events with checkpoint format and add middleware tag injection

- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
  BaseMessage.model_dump() format, matching the messages LangGraph stores in
  checkpoint values["messages"]
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
rayhpeng 2026-04-04 20:52:27 +08:00
parent 2d135aad0f
commit 52e7acafee
6 changed files with 356 additions and 98 deletions
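
The core change, sketched below for orientation (illustrative, not part of the diff): message-event content switches from a hand-rolled OpenAI-style dict to BaseMessage.model_dump(), the same serialization LangGraph persists in checkpoint values["messages"].

```python
# Minimal sketch of the before/after shape. Key names follow the tests in
# this commit; exact fields depend on the installed langchain_core version.
from langchain_core.messages import AIMessage

msg = AIMessage(content="Answer", response_metadata={"model_name": "test-model"})

old_content = {"role": "assistant", "content": "Answer"}  # pre-commit format
new_content = msg.model_dump()                            # checkpoint-aligned format
# new_content carries e.g. "type": "ai", "content", "additional_kwargs",
# "response_metadata", "tool_calls", "invalid_tool_calls", "usage_metadata", "id"
assert new_content["type"] == "ai" and new_content["content"] == "Answer"
```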


@ -56,13 +56,15 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
# Prepare keep parameter
keep = config.keep.to_tuple()
# Prepare model parameter
# Prepare model parameter.
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
kwargs = {
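
For context on the tag binding above, a minimal sketch of how tags attached via with_config() reach callback handlers, which is what lets RunJournal attribute these LLM calls to the middleware. The handler name and wiring here are assumptions for illustration, not from this repo.

```python
# Tags bound with .with_config(tags=[...]) are merged into the RunnableConfig
# of every subsequent invocation and handed to callbacks.
from langchain_core.callbacks import BaseCallbackHandler

class TagSniffer(BaseCallbackHandler):  # hypothetical handler for illustration
    def on_chat_model_start(self, serialized, messages, *, tags=None, **kwargs):
        print("tags:", tags)  # -> includes "middleware:summarize"

# tagged = model.with_config(tags=["middleware:summarize"])
# tagged.invoke("Summarize ...", config={"callbacks": [TagSniffer()]})
```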


@ -1,10 +1,11 @@
"""Middleware for automatic thread title generation."""
import logging
from typing import NotRequired, override
from typing import Any, NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
@ -100,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _get_runnable_config(self) -> dict[str, Any]:
"""Inherit the parent RunnableConfig and add middleware tag.
This ensures RunJournal identifies LLM calls from this middleware
as ``middleware:title`` instead of ``lead_agent``.
"""
try:
parent = get_config()
except Exception:
parent = {}
config = {**parent}
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None."""
if not self._should_generate_title(state):
@ -110,7 +125,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = model.invoke(prompt)
response = model.invoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
@ -130,7 +145,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = await model.ainvoke(prompt)
response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
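
To make the inheritance concrete, a worked example of what _get_runnable_config() produces; the parent config values here are made up.

```python
# Inside a LangGraph run, get_config() returns the ambient RunnableConfig of
# the parent invocation; the middleware copies it and appends its own tag, so
# callbacks and metadata still propagate while RunJournal sees "middleware:title".
parent = {"tags": ["lead_agent"], "metadata": {"thread_id": "t1"}}  # assumed parent config

config = {**parent}
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
assert config["tags"] == ["lead_agent", "middleware:title"]
# Outside a runnable context get_config() raises, hence the try/except
# fallback to an empty dict in the middleware above.
```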


@ -179,7 +179,8 @@ class RunJournal(BaseCallbackHandler):
},
)
# Message events: only lead_agent gets message-category events
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
@ -189,7 +190,7 @@ class RunJournal(BaseCallbackHandler):
self._put(
event_type="ai_tool_call",
category="message",
content=langchain_to_openai_message(message),
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
@ -197,10 +198,10 @@ class RunJournal(BaseCallbackHandler):
self._put(
event_type="ai_message",
category="message",
content={"role": "assistant", "content": content},
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content[:2000]
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
@ -242,45 +243,87 @@ class RunJournal(BaseCallbackHandler):
},
)
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=str(output),
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": "success",
"status": status,
},
)
# Message event: tool_result
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content={
"role": "tool",
"tool_call_id": tool_call_id or "",
"content": str(output),
},
metadata={"tool_name": tool_name},
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": kwargs.get("name", ""),
"tool_call_id": kwargs.get("tool_call_id"),
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
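
A standalone sketch of the branching on_tool_end now performs when building the tool_result content. The helper name is hypothetical and the run_id cache fallback is omitted; this is not the real RunJournal.

```python
from typing import Any
from langchain_core.messages import ToolMessage

def tool_result_content(output: Any, **kwargs: Any) -> dict:
    # ToolMessage fields win over kwargs, mirroring the priority order above.
    if isinstance(output, ToolMessage):
        content = output.model_dump()
        content["tool_call_id"] = output.tool_call_id or kwargs.get("tool_call_id", "")
        return content
    # Plain-string output: synthesize a checkpoint-aligned dict.
    return ToolMessage(
        content=str(output),
        tool_call_id=kwargs.get("tool_call_id", ""),
        name=kwargs.get("name", ""),
        status="success",
    ).model_dump()

print(tool_result_content("ok", tool_call_id="call_1", name="bash")["type"])  # -> "tool"
```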
@ -298,8 +341,8 @@ class RunJournal(BaseCallbackHandler):
},
)
self._put(
event_type="summary",
category="message",
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
@ -366,16 +409,24 @@ class RunJournal(BaseCallbackHandler):
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
def record_middleware(self, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware trace event.
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware state-change event.
Called by middleware implementations when they perform a meaningful
state change (e.g., title generation, summarization, HITL approval).
Pure-observation middleware should not call this.
Args:
tag: Short identifier for the middleware (e.g., "title", "summarize",
"guardrail"). Used to form event_type="middleware:{tag}".
name: Full middleware class name.
hook: Lifecycle hook that triggered the action (e.g., "after_model").
action: Specific action performed (e.g., "generate_title").
changes: Dict describing the state changes made.
"""
self._put(
event_type="middleware",
category="trace",
event_type=f"middleware:{tag}",
category="middleware",
content={"name": name, "hook": hook, "action": action, "changes": changes},
)
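
Under the new signature, a middleware reports a state change like this (usage sketch; journal is assumed to be the RunJournal attached to the run):

```python
# Emits event_type="middleware:title" with category="middleware", so the
# event shows up in list_events() but never in list_messages().
journal.record_middleware(
    "title",
    name="TitleMiddleware",
    hook="after_model",
    action="generate_title",
    changes={"title": "Quantum Computing"},
)
```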


@ -67,9 +67,9 @@ async def run_agent(
track_token_usage=getattr(run_events_config, "track_token_usage", True),
)
# Write human_message event
user_input = _extract_user_input(graph_input)
if user_input:
# Write human_message event (model_dump format, aligned with checkpoint)
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
@ -78,10 +78,11 @@ async def run_agent(
run_id=run_id,
event_type="human_message",
category="message",
content={"role": "user", "content": user_input},
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
journal.set_first_human_message(user_input)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# Track whether "events" was requested but skipped
if "events" in requested_modes:
@ -282,21 +283,29 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode
def _extract_user_input(graph_input: dict) -> str:
"""Extract user input text from graph_input for event recording."""
def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
"""Extract or construct a HumanMessage from graph_input for event recording.
Returns a LangChain HumanMessage so callers can use .model_dump() to get
the checkpoint-aligned serialization format.
"""
from langchain_core.messages import HumanMessage
messages = graph_input.get("messages")
if not messages:
return ""
# Take the last message (usually the user's input)
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, str):
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
content = last.content
return content if isinstance(content, str) else str(content)
return HumanMessage(content=content)
if isinstance(last, dict):
return str(last.get("content", ""))
return ""
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
def _unpack_stream_item(
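
A quick check of the input shapes _extract_human_message() accepts (illustrative; mirrors the function above):

```python
from langchain_core.messages import HumanMessage

assert _extract_human_message({}) is None
assert _extract_human_message({"messages": []}) is None
assert _extract_human_message({"messages": ["hi"]}).content == "hi"               # plain string
assert _extract_human_message({"messages": [{"content": "hi"}]}).content == "hi"  # dict
msg = HumanMessage(content="hi")
assert _extract_human_message({"messages": [msg]}) is msg                         # passthrough
```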


@ -146,8 +146,11 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
)
from unittest.mock import MagicMock
captured: dict[str, object] = {}
fake_model = object()
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
captured["name"] = name
@ -163,3 +166,4 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert middleware["model"] is fake_model
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])


@ -20,24 +20,32 @@ def journal_setup():
return j, store
def _make_llm_response(content="Hello", usage=None, tool_calls=None):
"""Create a mock LLM response with a message."""
def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
"""Create a mock LLM response with a message.
model_dump() returns checkpoint-aligned format matching real AIMessage.
"""
msg = MagicMock()
msg.type = "ai"
msg.content = content
msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or []
msg.invalid_tool_calls = []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
# Provide a real model_dump so serialize_lc_object returns a plain dict
# (needed for DB-backed tests where json.dumps must succeed).
msg.additional_kwargs = additional_kwargs or {}
msg.name = None
# model_dump returns checkpoint-aligned format
msg.model_dump.return_value = {
"type": "ai",
"content": content,
"additional_kwargs": additional_kwargs or {},
"response_metadata": {"model_name": "test-model"},
"type": "ai",
"name": None,
"id": msg.id,
"tool_calls": tool_calls or [],
"invalid_tool_calls": [],
"usage_metadata": usage,
"response_metadata": {"model_name": "test-model"},
}
gen = MagicMock()
@ -71,7 +79,9 @@ class TestLlmCallbacks:
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "ai_message"
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
# Content is checkpoint-aligned model_dump format
assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Answer"
@pytest.mark.anyio
async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
@ -211,10 +221,14 @@ class TestCustomEvents:
events = await store.list_events("t1", "r1")
trace = [e for e in events if e["event_type"] == "summarization"]
assert len(trace) == 1
# Summarization goes to middleware category, not message
mw_events = [e for e in events if e["event_type"] == "middleware:summarize"]
assert len(mw_events) == 1
assert mw_events[0]["category"] == "middleware"
assert mw_events[0]["content"] == {"role": "system", "content": "Context was summarized."}
# No message events from summarization
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "summary"
assert messages[0]["content"] == {"role": "system", "content": "Context was summarized."}
assert len(messages) == 0
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
@ -375,8 +389,11 @@ class TestDbBackedLifecycle:
record = await mgr.create("t1", "lead_agent")
run_id = record.run_id
# Write human_message
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content={"role": "user", "content": "Hello DB"})
# Write human_message (checkpoint-aligned format)
from langchain_core.messages import HumanMessage
human_msg = HumanMessage(content="Hello DB")
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content=human_msg.model_dump())
# Simulate journal
journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
@ -406,12 +423,14 @@ class TestDbBackedLifecycle:
assert row["status"] == "success"
assert row["total_tokens"] == 15
# Verify messages from DB
# Verify messages from DB (checkpoint-aligned format)
messages = await event_store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["event_type"] == "human_message"
assert messages[0]["content"]["type"] == "human"
assert messages[1]["event_type"] == "ai_message"
assert messages[1]["content"] == {"role": "assistant", "content": "DB response"}
assert messages[1]["content"]["type"] == "ai"
assert messages[1]["content"]["content"] == "DB response"
# Verify events from DB
events = await event_store.list_events("t1", run_id)
@ -560,38 +579,45 @@ class TestDictContent:
await close_engine()
class TestOpenAIHumanMessage:
class TestCheckpointAlignedHumanMessage:
@pytest.mark.anyio
async def test_human_message_openai_format(self):
async def test_human_message_checkpoint_format(self):
"""human_message content uses model_dump() checkpoint format."""
from langchain_core.messages import HumanMessage
store = MemoryRunEventStore()
human_msg = HumanMessage(content="What is AI?")
await store.put(
thread_id="t1",
run_id="r1",
event_type="human_message",
category="message",
content={"role": "user", "content": "What is AI?"},
content=human_msg.model_dump(),
metadata={"message_id": "msg_001"},
)
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"] == {"role": "user", "content": "What is AI?"}
assert messages[0]["content"]["role"] == "user"
assert messages[0]["content"]["type"] == "human"
assert messages[0]["content"]["content"] == "What is AI?"
class TestOpenAIMessageFormat:
class TestCheckpointAlignedMessageFormat:
@pytest.mark.anyio
async def test_ai_message_openai_format(self, journal_setup):
"""ai_message content should be OpenAI assistant message dict."""
async def test_ai_message_checkpoint_format(self, journal_setup):
"""ai_message content should be checkpoint-aligned model_dump dict."""
j, store = journal_setup
j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"])
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Answer"
assert "response_metadata" in messages[0]["content"]
assert "additional_kwargs" in messages[0]["content"]
@pytest.mark.anyio
async def test_ai_tool_call_event(self, journal_setup):
"""LLM response with tool_calls should produce ai_tool_call message event."""
"""LLM response with tool_calls should produce ai_tool_call with model_dump content."""
j, store = journal_setup
tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}]
j.on_llm_end(
@ -603,13 +629,12 @@ class TestOpenAIMessageFormat:
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "ai_tool_call"
assert messages[0]["content"]["role"] == "assistant"
assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Let me search"
assert len(messages[0]["content"]["tool_calls"]) == 1
tc = messages[0]["content"]["tool_calls"][0]
assert tc["id"] == "call_1"
assert tc["type"] == "function"
assert tc["function"]["name"] == "search"
assert tc["name"] == "search"
@pytest.mark.anyio
async def test_ai_tool_call_only_from_lead_agent(self, journal_setup):
@ -637,11 +662,11 @@ class TestToolResultMessage:
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "tool_result"
assert messages[0]["content"] == {
"role": "tool",
"tool_call_id": "call_abc",
"content": "search results here",
}
# Content is checkpoint-aligned model_dump format
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_abc"
assert messages[0]["content"]["content"] == "search results here"
assert messages[0]["content"]["name"] == "web_search"
@pytest.mark.anyio
async def test_tool_result_missing_tool_call_id(self, journal_setup):
@ -652,15 +677,128 @@ class TestToolResultMessage:
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"]["role"] == "tool"
assert messages[0]["content"]["type"] == "tool"
@pytest.mark.anyio
async def test_tool_error_no_tool_result_message(self, journal_setup):
async def test_tool_end_extracts_from_tool_message_object(self, journal_setup):
"""When LangChain passes a ToolMessage object as output, extract fields from it."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
tool_msg = ToolMessage(
content="search results",
tool_call_id="call_from_obj",
name="web_search",
status="success",
)
j.on_tool_end(tool_msg, run_id=run_id)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_from_obj"
assert messages[0]["content"]["content"] == "search results"
assert messages[0]["content"]["name"] == "web_search"
assert messages[0]["metadata"]["tool_name"] == "web_search"
assert messages[0]["metadata"]["status"] == "success"
events = await store.list_events("t1", "r1")
tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
assert tool_end["metadata"]["tool_name"] == "web_search"
@pytest.mark.anyio
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
"""ToolMessage object fields take priority over kwargs."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
tool_msg = ToolMessage(
content="result",
tool_call_id="call_obj",
name="tool_a",
status="success",
)
# Pass different values in kwargs — ToolMessage should win
j.on_tool_end(tool_msg, run_id=run_id, name="tool_b", tool_call_id="call_kwarg")
await j.flush()
messages = await store.list_messages("t1")
assert messages[0]["content"]["tool_call_id"] == "call_obj"
assert messages[0]["content"]["name"] == "tool_a"
assert messages[0]["metadata"]["tool_name"] == "tool_a"
@pytest.mark.anyio
async def test_tool_message_error_status(self, journal_setup):
"""ToolMessage with status='error' propagates status to metadata."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
tool_msg = ToolMessage(
content="something went wrong",
tool_call_id="call_err",
name="web_fetch",
status="error",
)
j.on_tool_end(tool_msg, run_id=run_id)
await j.flush()
events = await store.list_events("t1", "r1")
tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
assert tool_end["metadata"]["status"] == "error"
messages = await store.list_messages("t1")
assert messages[0]["content"]["status"] == "error"
assert messages[0]["metadata"]["status"] == "error"
@pytest.mark.anyio
async def test_tool_message_fallback_to_cache(self, journal_setup):
"""If ToolMessage has empty tool_call_id, fall back to cache from on_tool_start."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
j.on_tool_start({"name": "bash"}, "ls", run_id=run_id, tool_call_id="call_cached")
tool_msg = ToolMessage(
content="file list",
tool_call_id="",
name="bash",
)
j.on_tool_end(tool_msg, run_id=run_id)
await j.flush()
messages = await store.list_messages("t1")
assert messages[0]["content"]["tool_call_id"] == "call_cached"
@pytest.mark.anyio
async def test_tool_error_produces_tool_result_message(self, journal_setup):
j, store = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1")
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
assert len(messages) == 1
assert messages[0]["event_type"] == "tool_result"
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_1"
assert "timeout" in messages[0]["content"]["content"]
assert messages[0]["content"]["status"] == "error"
assert messages[0]["metadata"]["status"] == "error"
@pytest.mark.anyio
async def test_tool_error_uses_cached_tool_call_id(self, journal_setup):
"""on_tool_error should fall back to cached tool_call_id from on_tool_start."""
j, store = journal_setup
run_id = uuid4()
j.on_tool_start({"name": "web_fetch"}, "url", run_id=run_id, tool_call_id="call_cached")
j.on_tool_error(TimeoutError("timeout"), run_id=run_id, name="web_fetch")
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"]["tool_call_id"] == "call_cached"
def _make_base_messages():
@ -745,11 +883,12 @@ class TestLlmRequestResponse:
assert not any(e["event_type"] == "llm_start" for e in events)
class TestMiddlewareTrace:
class TestMiddlewareEvents:
@pytest.mark.anyio
async def test_record_middleware(self, journal_setup):
async def test_record_middleware_uses_middleware_category(self, journal_setup):
j, store = journal_setup
j.record_middleware(
"title",
name="TitleMiddleware",
hook="after_model",
action="generate_title",
@ -757,27 +896,60 @@ class TestMiddlewareTrace:
)
await j.flush()
events = await store.list_events("t1", "r1")
mw_events = [e for e in events if e["event_type"] == "middleware"]
mw_events = [e for e in events if e["event_type"] == "middleware:title"]
assert len(mw_events) == 1
assert mw_events[0]["category"] == "trace"
assert mw_events[0]["category"] == "middleware"
assert mw_events[0]["content"]["name"] == "TitleMiddleware"
assert mw_events[0]["content"]["hook"] == "after_model"
assert mw_events[0]["content"]["action"] == "generate_title"
assert mw_events[0]["content"]["changes"]["title"] == "Test Title"
@pytest.mark.anyio
async def test_middleware_events_not_in_messages(self, journal_setup):
"""Middleware events should not appear in list_messages()."""
j, store = journal_setup
j.record_middleware(
"title",
name="TitleMiddleware",
hook="after_model",
action="generate_title",
changes={"title": "Test"},
)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
@pytest.mark.anyio
async def test_middleware_tag_variants(self, journal_setup):
"""Different middleware tags produce distinct event_types."""
j, store = journal_setup
j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
await j.flush()
events = await store.list_events("t1", "r1")
event_types = {e["event_type"] for e in events}
assert "middleware:title" in event_types
assert "middleware:guardrail" in event_types
class TestFullRunSequence:
@pytest.mark.anyio
async def test_complete_run_event_sequence(self):
"""Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply."""
"""Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply.
All message events use checkpoint-aligned model_dump format.
"""
from langchain_core.messages import HumanMessage
store = MemoryRunEventStore()
j = RunJournal("r1", "t1", store, flush_threshold=100)
# 1. Human message (written by worker, not journal)
# 1. Human message (written by worker, using model_dump format)
human_msg = HumanMessage(content="Search for quantum computing")
await store.put(
thread_id="t1", run_id="r1",
event_type="human_message", category="message",
content={"role": "user", "content": "Search for quantum computing"},
content=human_msg.model_dump(),
)
j.set_first_human_message("Search for quantum computing")
@ -805,7 +977,7 @@ class TestFullRunSequence:
j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1")
# 5. Middleware: title generation
j.record_middleware("TitleMiddleware", "after_model", "generate_title", {"title": "Quantum Computing"})
j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={"title": "Quantum Computing"})
# 6. Second LLM call -> final reply
llm2_id = uuid4()
@ -824,18 +996,19 @@ class TestFullRunSequence:
await asyncio.sleep(0.05)
await j.flush()
# Verify message sequence (what gets exported for training)
# Verify message sequence
messages = await store.list_messages("t1")
msg_types = [m["event_type"] for m in messages]
assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"]
# Verify message content format
assert messages[0]["content"]["role"] == "user"
assert messages[1]["content"]["role"] == "assistant"
# Verify checkpoint-aligned format: all messages use "type" not "role"
assert messages[0]["content"]["type"] == "human"
assert messages[0]["content"]["content"] == "Search for quantum computing"
assert messages[1]["content"]["type"] == "ai"
assert "tool_calls" in messages[1]["content"]
assert messages[2]["content"]["role"] == "tool"
assert messages[2]["content"]["type"] == "tool"
assert messages[2]["content"]["tool_call_id"] == "call_1"
assert messages[3]["content"]["role"] == "assistant"
assert messages[3]["content"]["type"] == "ai"
assert messages[3]["content"]["content"] == "Here are the results about quantum computing..."
# Verify trace events
@ -845,10 +1018,14 @@ class TestFullRunSequence:
assert "llm_response" in trace_types
assert "tool_start" in trace_types
assert "tool_end" in trace_types
assert "middleware" in trace_types
assert "llm_start" not in trace_types
assert "llm_end" not in trace_types
# Verify middleware events are in their own category
mw_events = [e for e in events if e["category"] == "middleware"]
assert len(mw_events) == 1
assert mw_events[0]["event_type"] == "middleware:title"
# Verify token accumulation
data = j.get_completion_data()
assert data["total_tokens"] == 420 # 120 + 300
@ -857,7 +1034,7 @@ class TestFullRunSequence:
assert data["message_count"] == 1 # only final ai_message counts
assert data["last_ai_message"] == "Here are the results about quantum computing..."
# Verify training data export is trivial
training_messages = [m["content"] for m in messages]
assert all(isinstance(m, dict) for m in training_messages)
assert all("role" in m for m in training_messages)
# Verify all message contents are checkpoint-aligned dicts with "type" field
for m in messages:
assert isinstance(m["content"], dict)
assert "type" in m["content"]