diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py
index fe743e448..b1938608f 100644
--- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py
+++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py
@@ -56,13 +56,15 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
     # Prepare keep parameter
     keep = config.keep.to_tuple()
 
-    # Prepare model parameter
+    # Prepare model parameter.
+    # Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
+    # as middleware rather than lead_agent (SummarizationMiddleware is a
+    # LangChain built-in, so we tag the model at creation time).
     if config.model_name:
         model = create_chat_model(name=config.model_name, thinking_enabled=False)
     else:
-        # Use a lightweight model for summarization to save costs
-        # Falls back to default model if not explicitly specified
        model = create_chat_model(thinking_enabled=False)
+    model = model.with_config(tags=["middleware:summarize"])
 
     # Prepare kwargs
     kwargs = {
diff --git a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
index 20cb02f68..289fc3120 100644
--- a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
+++ b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py
@@ -1,10 +1,11 @@
 """Middleware for automatic thread title generation."""
 
 import logging
-from typing import NotRequired, override
+from typing import Any, NotRequired, override
 
 from langchain.agents import AgentState
 from langchain.agents.middleware import AgentMiddleware
+from langgraph.config import get_config
 from langgraph.runtime import Runtime
 
 from deerflow.config.title_config import get_title_config
@@ -100,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
             return user_msg[:fallback_chars].rstrip() + "..."
         return user_msg if user_msg else "New Conversation"
 
+    def _get_runnable_config(self) -> dict[str, Any]:
+        """Inherit the parent RunnableConfig and add middleware tag.
+
+        This ensures RunJournal identifies LLM calls from this middleware
+        as ``middleware:title`` instead of ``lead_agent``.
+        """
+        try:
+            parent = get_config()
+        except Exception:
+            parent = {}
+        config = {**parent}
+        config["tags"] = [*(config.get("tags") or []), "middleware:title"]
+        return config
+
     def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
         """Synchronously generate a title. Returns state update or None."""
         if not self._should_generate_title(state):
@@ -110,7 +125,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
         model = create_chat_model(name=config.model_name, thinking_enabled=False)
 
         try:
-            response = model.invoke(prompt)
+            response = model.invoke(prompt, config=self._get_runnable_config())
             title = self._parse_title(response.content)
             if not title:
                 title = self._fallback_title(user_msg)
@@ -130,7 +145,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
         model = create_chat_model(name=config.model_name, thinking_enabled=False)
 
         try:
-            response = await model.ainvoke(prompt)
+            response = await model.ainvoke(prompt, config=self._get_runnable_config())
             title = self._parse_title(response.content)
             if not title:
                 title = self._fallback_title(user_msg)
diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py
index dccc9f481..15e48bab0 100644
--- a/backend/packages/harness/deerflow/runtime/journal.py
+++ b/backend/packages/harness/deerflow/runtime/journal.py
@@ -179,7 +179,8 @@ class RunJournal(BaseCallbackHandler):
             },
         )
 
-        # Message events: only lead_agent gets message-category events
+        # Message events: only lead_agent gets message-category events.
+        # Content uses message.model_dump() to align with checkpoint format.
         tool_calls = getattr(message, "tool_calls", None) or []
         if caller == "lead_agent":
             resp_meta = getattr(message, "response_metadata", None) or {}
@@ -189,7 +190,7 @@ class RunJournal(BaseCallbackHandler):
                 self._put(
                     event_type="ai_tool_call",
                     category="message",
-                    content=langchain_to_openai_message(message),
+                    content=message.model_dump(),
                     metadata={"model_name": model_name, "finish_reason": "tool_calls"},
                 )
             elif isinstance(content, str) and content:
@@ -197,10 +198,10 @@ class RunJournal(BaseCallbackHandler):
                 self._put(
                     event_type="ai_message",
                     category="message",
-                    content={"role": "assistant", "content": content},
+                    content=message.model_dump(),
                     metadata={"model_name": model_name, "finish_reason": "stop"},
                 )
-                self._last_ai_msg = content[:2000]
+                self._last_ai_msg = content
                 self._msg_count += 1
 
         # Token accumulation
@@ -242,45 +243,87 @@ class RunJournal(BaseCallbackHandler):
             },
         )
 
-    def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
-        tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
-        tool_name = kwargs.get("name", "")
+    def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
+        from langchain_core.messages import ToolMessage
+
+        # Extract fields from ToolMessage object when LangChain provides one.
+        # LangChain's _format_output wraps tool results into a ToolMessage
+        # with tool_call_id, name, status, and artifact — more complete than
+        # what kwargs alone provides.
+        if isinstance(output, ToolMessage):
+            tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
+            tool_name = output.name or kwargs.get("name", "")
+            status = getattr(output, "status", "success") or "success"
+            content_str = output.content if isinstance(output.content, str) else str(output.content)
+            # Use model_dump() for checkpoint-aligned message content.
+            # Override tool_call_id if it was resolved from cache.
+            msg_content = output.model_dump()
+            if msg_content.get("tool_call_id") != tool_call_id:
+                msg_content["tool_call_id"] = tool_call_id
+        else:
+            tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
+            tool_name = kwargs.get("name", "")
+            status = "success"
+            content_str = str(output)
+            # Construct checkpoint-aligned dict when output is a plain string.
+            msg_content = ToolMessage(
+                content=content_str,
+                tool_call_id=tool_call_id or "",
+                name=tool_name,
+                status=status,
+            ).model_dump()
 
         # Trace event (always)
         self._put(
             event_type="tool_end",
             category="trace",
-            content=str(output),
+            content=content_str,
             metadata={
                 "tool_name": tool_name,
                 "tool_call_id": tool_call_id,
-                "status": "success",
+                "status": status,
             },
         )
 
-        # Message event: tool_result
+        # Message event: tool_result (checkpoint-aligned model_dump format)
        self._put(
             event_type="tool_result",
             category="message",
-            content={
-                "role": "tool",
-                "tool_call_id": tool_call_id or "",
-                "content": str(output),
-            },
-            metadata={"tool_name": tool_name},
+            content=msg_content,
+            metadata={"tool_name": tool_name, "status": status},
         )
 
     def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
+        from langchain_core.messages import ToolMessage
+
+        tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
+        tool_name = kwargs.get("name", "")
+
+        # Trace event
         self._put(
             event_type="tool_error",
             category="trace",
             content=str(error),
             metadata={
-                "tool_name": kwargs.get("name", ""),
-                "tool_call_id": kwargs.get("tool_call_id"),
+                "tool_name": tool_name,
+                "tool_call_id": tool_call_id,
             },
         )
 
+        # Message event: tool_result with error status (checkpoint-aligned)
+        msg_content = ToolMessage(
+            content=str(error),
+            tool_call_id=tool_call_id or "",
+            name=tool_name,
+            status="error",
+        ).model_dump()
+        self._put(
+            event_type="tool_result",
+            category="message",
+            content=msg_content,
+            metadata={"tool_name": tool_name, "status": "error"},
+        )
+
     # -- Custom event callback --
 
     def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
@@ -298,8 +341,8 @@ class RunJournal(BaseCallbackHandler):
                 },
             )
             self._put(
-                event_type="summary",
-                category="message",
+                event_type="middleware:summarize",
+                category="middleware",
                 content={"role": "system", "content": data_dict.get("summary", "")},
                 metadata={"replaced_count": data_dict.get("replaced_count", 0)},
             )
@@ -366,16 +409,24 @@ class RunJournal(BaseCallbackHandler):
         """Record the first human message for convenience fields."""
         self._first_human_msg = content[:2000] if content else None
 
-    def record_middleware(self, name: str, hook: str, action: str, changes: dict) -> None:
-        """Record a middleware trace event.
+    def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
+        """Record a middleware state-change event.
 
         Called by middleware implementations when they perform a meaningful
         state change (e.g., title generation, summarization, HITL approval).
         Pure-observation middleware should not call this.
+
+        Args:
+            tag: Short identifier for the middleware (e.g., "title", "summarize",
+                "guardrail"). Used to form event_type="middleware:{tag}".
+            name: Full middleware class name.
+            hook: Lifecycle hook that triggered the action (e.g., "after_model").
+            action: Specific action performed (e.g., "generate_title").
+            changes: Dict describing the state changes made.
         """
         self._put(
-            event_type="middleware",
-            category="trace",
+            event_type=f"middleware:{tag}",
+            category="middleware",
             content={"name": name, "hook": hook, "action": action, "changes": changes},
         )
 
diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py
index 81bf959e8..34d5b7e0f 100644
--- a/backend/packages/harness/deerflow/runtime/runs/worker.py
+++ b/backend/packages/harness/deerflow/runtime/runs/worker.py
@@ -67,9 +67,9 @@ async def run_agent(
         track_token_usage=getattr(run_events_config, "track_token_usage", True),
     )
 
-    # Write human_message event
-    user_input = _extract_user_input(graph_input)
-    if user_input:
+    # Write human_message event (model_dump format, aligned with checkpoint)
+    human_msg = _extract_human_message(graph_input)
+    if human_msg is not None:
         msg_metadata = {}
         if follow_up_to_run_id:
             msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
@@ -78,10 +78,11 @@ async def run_agent(
             run_id=run_id,
             event_type="human_message",
             category="message",
-            content={"role": "user", "content": user_input},
+            content=human_msg.model_dump(),
             metadata=msg_metadata or None,
         )
-        journal.set_first_human_message(user_input)
+        content = human_msg.content
+        journal.set_first_human_message(content if isinstance(content, str) else str(content))
 
     # Track whether "events" was requested but skipped
     if "events" in requested_modes:
@@ -282,21 +283,29 @@ def _lg_mode_to_sse_event(mode: str) -> str:
     return mode
 
 
-def _extract_user_input(graph_input: dict) -> str:
-    """Extract user input text from graph_input for event recording."""
+def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
+    """Extract or construct a HumanMessage from graph_input for event recording.
+
+    Returns a LangChain HumanMessage so callers can use .model_dump() to get
+    the checkpoint-aligned serialization format.
+    """
+    from langchain_core.messages import HumanMessage
+
     messages = graph_input.get("messages")
     if not messages:
-        return ""
-    # Take the last message (usually the user's input)
+        return None
     last = messages[-1] if isinstance(messages, list) else messages
-    if isinstance(last, str):
+    if isinstance(last, HumanMessage):
         return last
+    if isinstance(last, str):
+        return HumanMessage(content=last) if last else None
     if hasattr(last, "content"):
         content = last.content
-        return content if isinstance(content, str) else str(content)
+        return HumanMessage(content=content)
     if isinstance(last, dict):
-        return str(last.get("content", ""))
-    return ""
+        content = last.get("content", "")
+        return HumanMessage(content=content) if content else None
+    return None
 
 
 def _unpack_stream_item(
diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py
index 9373c2895..093baaa92 100644
--- a/backend/tests/test_lead_agent_model_resolution.py
+++ b/backend/tests/test_lead_agent_model_resolution.py
@@ -146,8 +146,11 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
         lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
     )
 
+    from unittest.mock import MagicMock
+
     captured: dict[str, object] = {}
-    fake_model = object()
+    fake_model = MagicMock()
+    fake_model.with_config.return_value = fake_model
 
     def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
         captured["name"] = name
@@ -163,3 +166,4 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
     assert captured["name"] == "model-masswork"
     assert captured["thinking_enabled"] is False
     assert middleware["model"] is fake_model
+    fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py
index e32f05cd1..4e6969393 100644
--- a/backend/tests/test_run_journal.py
+++ b/backend/tests/test_run_journal.py
@@ -20,24 +20,32 @@ def journal_setup():
     return j, store
 
 
-def _make_llm_response(content="Hello", usage=None, tool_calls=None):
-    """Create a mock LLM response with a message."""
+def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
+    """Create a mock LLM response with a message.
+
+    model_dump() returns checkpoint-aligned format matching real AIMessage.
+    """
     msg = MagicMock()
     msg.type = "ai"
     msg.content = content
     msg.id = f"msg-{id(msg)}"
     msg.tool_calls = tool_calls or []
+    msg.invalid_tool_calls = []
     msg.response_metadata = {"model_name": "test-model"}
     msg.usage_metadata = usage
-    # Provide a real model_dump so serialize_lc_object returns a plain dict
-    # (needed for DB-backed tests where json.dumps must succeed).
+    msg.additional_kwargs = additional_kwargs or {}
+    msg.name = None
+    # model_dump returns checkpoint-aligned format
     msg.model_dump.return_value = {
-        "type": "ai",
         "content": content,
+        "additional_kwargs": additional_kwargs or {},
+        "response_metadata": {"model_name": "test-model"},
+        "type": "ai",
+        "name": None,
         "id": msg.id,
         "tool_calls": tool_calls or [],
+        "invalid_tool_calls": [],
         "usage_metadata": usage,
-        "response_metadata": {"model_name": "test-model"},
     }
 
     gen = MagicMock()
@@ -71,7 +79,9 @@ class TestLlmCallbacks:
         messages = await store.list_messages("t1")
         assert len(messages) == 1
         assert messages[0]["event_type"] == "ai_message"
-        assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
+        # Content is checkpoint-aligned model_dump format
+        assert messages[0]["content"]["type"] == "ai"
+        assert messages[0]["content"]["content"] == "Answer"
 
     @pytest.mark.anyio
     async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
@@ -211,10 +221,14 @@ class TestCustomEvents:
         events = await store.list_events("t1", "r1")
         trace = [e for e in events if e["event_type"] == "summarization"]
         assert len(trace) == 1
+        # Summarization goes to middleware category, not message
+        mw_events = [e for e in events if e["event_type"] == "middleware:summarize"]
+        assert len(mw_events) == 1
+        assert mw_events[0]["category"] == "middleware"
+        assert mw_events[0]["content"] == {"role": "system", "content": "Context was summarized."}
+        # No message events from summarization
         messages = await store.list_messages("t1")
-        assert len(messages) == 1
-        assert messages[0]["event_type"] == "summary"
-        assert messages[0]["content"] == {"role": "system", "content": "Context was summarized."}
+        assert len(messages) == 0
 
     @pytest.mark.anyio
     async def test_non_summarization_custom_event(self, journal_setup):
@@ -375,8 +389,11 @@ class TestDbBackedLifecycle:
         record = await mgr.create("t1", "lead_agent")
         run_id = record.run_id
 
-        # Write human_message
-        await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content={"role": "user", "content": "Hello DB"})
+        # Write human_message (checkpoint-aligned format)
+        from langchain_core.messages import HumanMessage
+
+        human_msg = HumanMessage(content="Hello DB")
+        await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content=human_msg.model_dump())
 
         # Simulate journal
         journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
@@ -406,12 +423,14 @@ class TestDbBackedLifecycle:
         assert row["status"] == "success"
         assert row["total_tokens"] == 15
 
-        # Verify messages from DB
+        # Verify messages from DB (checkpoint-aligned format)
         messages = await event_store.list_messages("t1")
         assert len(messages) == 2
         assert messages[0]["event_type"] == "human_message"
+        assert messages[0]["content"]["type"] == "human"
         assert messages[1]["event_type"] == "ai_message"
-        assert messages[1]["content"] == {"role": "assistant", "content": "DB response"}
+        assert messages[1]["content"]["type"] == "ai"
+        assert messages[1]["content"]["content"] == "DB response"
 
         # Verify events from DB
         events = await event_store.list_events("t1", run_id)
@@ -560,38 +579,45 @@ class TestDictContent:
         await close_engine()
 
 
-class TestOpenAIHumanMessage:
+class TestCheckpointAlignedHumanMessage:
     @pytest.mark.anyio
-    async def test_human_message_openai_format(self):
+    async def test_human_message_checkpoint_format(self):
+        """human_message content uses model_dump() checkpoint format."""
+        from langchain_core.messages import HumanMessage
+
         store = MemoryRunEventStore()
+        human_msg = HumanMessage(content="What is AI?")
         await store.put(
             thread_id="t1",
             run_id="r1",
             event_type="human_message",
             category="message",
-            content={"role": "user", "content": "What is AI?"},
+            content=human_msg.model_dump(),
             metadata={"message_id": "msg_001"},
         )
         messages = await store.list_messages("t1")
         assert len(messages) == 1
-        assert messages[0]["content"] == {"role": "user", "content": "What is AI?"}
-        assert messages[0]["content"]["role"] == "user"
+        assert messages[0]["content"]["type"] == "human"
+        assert messages[0]["content"]["content"] == "What is AI?"
 
 
-class TestOpenAIMessageFormat:
+class TestCheckpointAlignedMessageFormat:
     @pytest.mark.anyio
-    async def test_ai_message_openai_format(self, journal_setup):
-        """ai_message content should be OpenAI assistant message dict."""
+    async def test_ai_message_checkpoint_format(self, journal_setup):
+        """ai_message content should be checkpoint-aligned model_dump dict."""
         j, store = journal_setup
         j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"])
         await j.flush()
         messages = await store.list_messages("t1")
         assert len(messages) == 1
-        assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
+        assert messages[0]["content"]["type"] == "ai"
+        assert messages[0]["content"]["content"] == "Answer"
+        assert "response_metadata" in messages[0]["content"]
+        assert "additional_kwargs" in messages[0]["content"]
 
     @pytest.mark.anyio
     async def test_ai_tool_call_event(self, journal_setup):
-        """LLM response with tool_calls should produce ai_tool_call message event."""
+        """LLM response with tool_calls should produce ai_tool_call with model_dump content."""
         j, store = journal_setup
         tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}]
         j.on_llm_end(
@@ -603,13 +629,12 @@ class TestOpenAIMessageFormat:
         messages = await store.list_messages("t1")
         assert len(messages) == 1
         assert messages[0]["event_type"] == "ai_tool_call"
-        assert messages[0]["content"]["role"] == "assistant"
+        assert messages[0]["content"]["type"] == "ai"
         assert messages[0]["content"]["content"] == "Let me search"
         assert len(messages[0]["content"]["tool_calls"]) == 1
         tc = messages[0]["content"]["tool_calls"][0]
         assert tc["id"] == "call_1"
-        assert tc["type"] == "function"
-        assert tc["function"]["name"] == "search"
+        assert tc["name"] == "search"
 
     @pytest.mark.anyio
     async def test_ai_tool_call_only_from_lead_agent(self, journal_setup):
@@ -637,11 +662,11 @@ class TestToolResultMessage:
         messages = await store.list_messages("t1")
         assert len(messages) == 1
         assert messages[0]["event_type"] == "tool_result"
-        assert messages[0]["content"] == {
-            "role": "tool",
-            "tool_call_id": "call_abc",
-            "content": "search results here",
-        }
+        # Content is checkpoint-aligned model_dump format
+        assert messages[0]["content"]["type"] == "tool"
+        assert messages[0]["content"]["tool_call_id"] == "call_abc"
+        assert messages[0]["content"]["content"] == "search results here"
+        assert messages[0]["content"]["name"] == "web_search"
 
     @pytest.mark.anyio
     async def test_tool_result_missing_tool_call_id(self, journal_setup):
@@ -652,15 +677,128 @@ class TestToolResultMessage:
         await j.flush()
         messages = await store.list_messages("t1")
         assert len(messages) == 1
-        assert messages[0]["content"]["role"] == "tool"
+        assert messages[0]["content"]["type"] == "tool"
 
     @pytest.mark.anyio
-    async def test_tool_error_no_tool_result_message(self, journal_setup):
+    async def test_tool_end_extracts_from_tool_message_object(self, journal_setup):
+        """When LangChain passes a ToolMessage object as output, extract fields from it."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        tool_msg = ToolMessage(
+            content="search results",
+            tool_call_id="call_from_obj",
+            name="web_search",
+            status="success",
+        )
+        j.on_tool_end(tool_msg, run_id=run_id)
+        await j.flush()
+
+        messages = await store.list_messages("t1")
+        assert len(messages) == 1
+        assert messages[0]["content"]["type"] == "tool"
+        assert messages[0]["content"]["tool_call_id"] == "call_from_obj"
+        assert messages[0]["content"]["content"] == "search results"
+        assert messages[0]["content"]["name"] == "web_search"
+        assert messages[0]["metadata"]["tool_name"] == "web_search"
+        assert messages[0]["metadata"]["status"] == "success"
+
+        events = await store.list_events("t1", "r1")
+        tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
+        assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
+        assert tool_end["metadata"]["tool_name"] == "web_search"
+
+    @pytest.mark.anyio
+    async def test_tool_message_object_overrides_kwargs(self, journal_setup):
+        """ToolMessage object fields take priority over kwargs."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        tool_msg = ToolMessage(
+            content="result",
+            tool_call_id="call_obj",
+            name="tool_a",
+            status="success",
+        )
+        # Pass different values in kwargs — ToolMessage should win
+        j.on_tool_end(tool_msg, run_id=run_id, name="tool_b", tool_call_id="call_kwarg")
+        await j.flush()
+
+        messages = await store.list_messages("t1")
+        assert messages[0]["content"]["tool_call_id"] == "call_obj"
+        assert messages[0]["content"]["name"] == "tool_a"
+        assert messages[0]["metadata"]["tool_name"] == "tool_a"
+
+    @pytest.mark.anyio
+    async def test_tool_message_error_status(self, journal_setup):
+        """ToolMessage with status='error' propagates status to metadata."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        tool_msg = ToolMessage(
+            content="something went wrong",
+            tool_call_id="call_err",
+            name="web_fetch",
+            status="error",
+        )
+        j.on_tool_end(tool_msg, run_id=run_id)
+        await j.flush()
+
+        events = await store.list_events("t1", "r1")
+        tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
+        assert tool_end["metadata"]["status"] == "error"
+
+        messages = await store.list_messages("t1")
+        assert messages[0]["content"]["status"] == "error"
+        assert messages[0]["metadata"]["status"] == "error"
+
+    @pytest.mark.anyio
+    async def test_tool_message_fallback_to_cache(self, journal_setup):
+        """If ToolMessage has empty tool_call_id, fall back to cache from on_tool_start."""
+        from langchain_core.messages import ToolMessage
+
+        j, store = journal_setup
+        run_id = uuid4()
+        j.on_tool_start({"name": "bash"}, "ls", run_id=run_id, tool_call_id="call_cached")
+        tool_msg = ToolMessage(
+            content="file list",
+            tool_call_id="",
+            name="bash",
+        )
+        j.on_tool_end(tool_msg, run_id=run_id)
+        await j.flush()
+
+        messages = await store.list_messages("t1")
+        assert messages[0]["content"]["tool_call_id"] == "call_cached"
+
+    @pytest.mark.anyio
+    async def test_tool_error_produces_tool_result_message(self, journal_setup):
         j, store = journal_setup
         j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1")
         await j.flush()
         messages = await store.list_messages("t1")
-        assert len(messages) == 0
+        assert len(messages) == 1
+        assert messages[0]["event_type"] == "tool_result"
+        assert messages[0]["content"]["type"] == "tool"
+        assert messages[0]["content"]["tool_call_id"] == "call_1"
+        assert "timeout" in messages[0]["content"]["content"]
+        assert messages[0]["content"]["status"] == "error"
+        assert messages[0]["metadata"]["status"] == "error"
+
+    @pytest.mark.anyio
+    async def test_tool_error_uses_cached_tool_call_id(self, journal_setup):
+        """on_tool_error should fall back to cached tool_call_id from on_tool_start."""
+        j, store = journal_setup
+        run_id = uuid4()
+        j.on_tool_start({"name": "web_fetch"}, "url", run_id=run_id, tool_call_id="call_cached")
+        j.on_tool_error(TimeoutError("timeout"), run_id=run_id, name="web_fetch")
+        await j.flush()
+        messages = await store.list_messages("t1")
+        assert len(messages) == 1
+        assert messages[0]["content"]["tool_call_id"] == "call_cached"
 
 
 def _make_base_messages():
@@ -745,11 +883,12 @@ class TestLlmRequestResponse:
         assert not any(e["event_type"] == "llm_start" for e in events)
 
 
-class TestMiddlewareTrace:
+class TestMiddlewareEvents:
     @pytest.mark.anyio
-    async def test_record_middleware(self, journal_setup):
+    async def test_record_middleware_uses_middleware_category(self, journal_setup):
         j, store = journal_setup
         j.record_middleware(
+            "title",
             name="TitleMiddleware",
             hook="after_model",
             action="generate_title",
@@ -757,27 +896,60 @@ class TestMiddlewareTrace:
         )
         await j.flush()
         events = await store.list_events("t1", "r1")
-        mw_events = [e for e in events if e["event_type"] == "middleware"]
+        mw_events = [e for e in events if e["event_type"] == "middleware:title"]
         assert len(mw_events) == 1
-        assert mw_events[0]["category"] == "trace"
+        assert mw_events[0]["category"] == "middleware"
         assert mw_events[0]["content"]["name"] == "TitleMiddleware"
         assert mw_events[0]["content"]["hook"] == "after_model"
         assert mw_events[0]["content"]["action"] == "generate_title"
        assert mw_events[0]["content"]["changes"]["title"] == "Test Title"
 
+    @pytest.mark.anyio
+    async def test_middleware_events_not_in_messages(self, journal_setup):
+        """Middleware events should not appear in list_messages()."""
+        j, store = journal_setup
+        j.record_middleware(
+            "title",
+            name="TitleMiddleware",
+            hook="after_model",
+            action="generate_title",
+            changes={"title": "Test"},
+        )
+        await j.flush()
+        messages = await store.list_messages("t1")
+        assert len(messages) == 0
+
+    @pytest.mark.anyio
+    async def test_middleware_tag_variants(self, journal_setup):
+        """Different middleware tags produce distinct event_types."""
+        j, store = journal_setup
+        j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
+        j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
+        await j.flush()
+        events = await store.list_events("t1", "r1")
+        event_types = {e["event_type"] for e in events}
+        assert "middleware:title" in event_types
+        assert "middleware:guardrail" in event_types
+
 
 class TestFullRunSequence:
     @pytest.mark.anyio
     async def test_complete_run_event_sequence(self):
-        """Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply."""
+        """Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply.
+
+        All message events use checkpoint-aligned model_dump format.
+        """
+        from langchain_core.messages import HumanMessage
+
         store = MemoryRunEventStore()
         j = RunJournal("r1", "t1", store, flush_threshold=100)
-        # 1. Human message (written by worker, not journal)
+        # 1. Human message (written by worker, using model_dump format)
+        human_msg = HumanMessage(content="Search for quantum computing")
         await store.put(
             thread_id="t1",
             run_id="r1",
             event_type="human_message",
             category="message",
-            content={"role": "user", "content": "Search for quantum computing"},
+            content=human_msg.model_dump(),
         )
         j.set_first_human_message("Search for quantum computing")
@@ -805,7 +977,7 @@ class TestFullRunSequence:
         j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1")
 
         # 5. Middleware: title generation
-        j.record_middleware("TitleMiddleware", "after_model", "generate_title", {"title": "Quantum Computing"})
+        j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={"title": "Quantum Computing"})
 
         # 6. Second LLM call -> final reply
         llm2_id = uuid4()
@@ -824,18 +996,19 @@ class TestFullRunSequence:
         await asyncio.sleep(0.05)
         await j.flush()
 
-        # Verify message sequence (what gets exported for training)
+        # Verify message sequence
         messages = await store.list_messages("t1")
         msg_types = [m["event_type"] for m in messages]
         assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"]
 
-        # Verify message content format
-        assert messages[0]["content"]["role"] == "user"
-        assert messages[1]["content"]["role"] == "assistant"
+        # Verify checkpoint-aligned format: all messages use "type" not "role"
+        assert messages[0]["content"]["type"] == "human"
+        assert messages[0]["content"]["content"] == "Search for quantum computing"
+        assert messages[1]["content"]["type"] == "ai"
         assert "tool_calls" in messages[1]["content"]
-        assert messages[2]["content"]["role"] == "tool"
+        assert messages[2]["content"]["type"] == "tool"
         assert messages[2]["content"]["tool_call_id"] == "call_1"
-        assert messages[3]["content"]["role"] == "assistant"
+        assert messages[3]["content"]["type"] == "ai"
         assert messages[3]["content"]["content"] == "Here are the results about quantum computing..."
 
         # Verify trace events
@@ -845,10 +1018,14 @@ class TestFullRunSequence:
         assert "llm_response" in trace_types
         assert "tool_start" in trace_types
         assert "tool_end" in trace_types
-        assert "middleware" in trace_types
         assert "llm_start" not in trace_types
         assert "llm_end" not in trace_types
 
+        # Verify middleware events are in their own category
+        mw_events = [e for e in events if e["category"] == "middleware"]
+        assert len(mw_events) == 1
+        assert mw_events[0]["event_type"] == "middleware:title"
+
         # Verify token accumulation
         data = j.get_completion_data()
         assert data["total_tokens"] == 420  # 120 + 300
@@ -857,7 +1034,7 @@ class TestFullRunSequence:
         assert data["message_count"] == 1  # only final ai_message counts
         assert data["last_ai_message"] == "Here are the results about quantum computing..."
 
-        # Verify training data export is trivial
-        training_messages = [m["content"] for m in messages]
-        assert all(isinstance(m, dict) for m in training_messages)
-        assert all("role" in m for m in training_messages)
+        # Verify all message contents are checkpoint-aligned dicts with "type" field
+        for m in messages:
+            assert isinstance(m["content"], dict)
+            assert "type" in m["content"]
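
Note on the content-format change, as a minimal sketch (not part of the patch): the journal previously stored hand-built OpenAI-style role dicts, and it now stores the LangChain message's model_dump() output, which matches what checkpoints serialize. This assumes langchain-core's standard ToolMessage; the exact set of keys can vary by version.

from langchain_core.messages import ToolMessage

msg = ToolMessage(content="search results", tool_call_id="call_1", name="web_search")

# Previously journaled (hand-built OpenAI-style dict):
#   {"role": "tool", "tool_call_id": "call_1", "content": "search results"}

# Now journaled (checkpoint-aligned, via model_dump()):
#   {"type": "tool", "content": "search results", "tool_call_id": "call_1",
#    "name": "web_search", "status": "success", ...}
print(msg.model_dump())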