diff --git a/backend/packages/harness/deerflow/runtime/converters.py b/backend/packages/harness/deerflow/runtime/converters.py
index 811031160..79d3b2b84 100644
--- a/backend/packages/harness/deerflow/runtime/converters.py
+++ b/backend/packages/harness/deerflow/runtime/converters.py
@@ -1,6 +1,8 @@
 """Pure functions to convert LangChain message objects to OpenAI Chat Completions format.

-Used by RunJournal to build content dicts for event storage.
+Utility for translating LangChain message types to OpenAI-compatible dicts.
+Not currently wired into RunJournal (which uses message.model_dump() directly),
+but available for consumers that need the OpenAI wire format.
 """

 from __future__ import annotations
diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py
index 688cde200..a0c2d029b 100644
--- a/backend/packages/harness/deerflow/runtime/journal.py
+++ b/backend/packages/harness/deerflow/runtime/journal.py
@@ -62,9 +62,6 @@ class RunJournal(BaseCallbackHandler):
         self._total_output_tokens = 0
         self._total_tokens = 0
         self._llm_call_count = 0
-        self._lead_agent_tokens = 0
-        self._subagent_tokens = 0
-        self._middleware_tokens = 0

         # Convenience fields
         self._last_ai_msg: str | None = None
@@ -76,7 +73,7 @@ class RunJournal(BaseCallbackHandler):

         # LLM request/response tracking
         self._llm_call_index = 0
-        self._cached_prompts: dict[str, list[dict]] = {}  # langchain run_id -> OpenAI messages
+        self._seen_llm_starts: set[str] = set()  # langchain run_ids that fired on_chat_model_start

     # -- Lifecycle callbacks --

@@ -135,8 +132,7 @@ class RunJournal(BaseCallbackHandler):
         rid = str(run_id)
         self._llm_start_times[rid] = time.monotonic()
         self._llm_call_index += 1
-        # Mark this run_id as seen so on_llm_end knows not to increment again.
-        self._cached_prompts[rid] = []
+        self._seen_llm_starts.add(rid)

         logger.debug(
             "on_chat_model_start %s: tags=%s num_batches=%d message_counts=%s",
@@ -148,8 +144,8 @@

         # Capture the first human message sent to any LLM in this run.
         if not self._first_human_msg and messages:
-            for batch in messages.reversed():
-                for m in batch.reversed():
+            for batch in reversed(messages):
+                for m in reversed(batch):
                     if isinstance(m, HumanMessage) and m.name != "summary":
                         caller = self._identify_caller(tags)
                         self.set_first_human_message(m.text)
@@ -167,9 +163,17 @@
         # Fallback: on_chat_model_start is preferred. This just tracks latency.
         self._llm_start_times[str(run_id)] = time.monotonic()

-    def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
+    def on_llm_end(
+        self,
+        response: Any,
+        *,
+        run_id: UUID,
+        parent_run_id: UUID | None = None,
+        tags: list[str] | None = None,
+        **kwargs: Any,
+    ) -> None:
         messages: list[AnyMessage] = []
-        logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
+        logger.debug("on_llm_end %s: tags=%s", run_id, tags)
         for generation in response.generations:
             for gen in generation:
                 if hasattr(gen, "message"):
@@ -191,10 +195,11 @@ class RunJournal(BaseCallbackHandler):

         # Resolve call index
         call_index = self._llm_call_index
-        if rid not in self._cached_prompts:
+        if rid not in self._seen_llm_starts:
             # Fallback: on_chat_model_start was not called
             self._llm_call_index += 1
             call_index = self._llm_call_index
+            self._seen_llm_starts.add(rid)

         # Trace event: llm_response (OpenAI completion format)
         self._put(
@@ -229,7 +234,7 @@ class RunJournal(BaseCallbackHandler):
     def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
         """Handle tool start event, cache tool call ID for later correlation"""
         tool_call_id = str(run_id)
-        logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
+        logger.debug("Tool start for node %s, tool_call_id=%s, tags=%s", run_id, tool_call_id, tags)

     def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
         """Handle tool end event, append message and clear node data"""
@@ -248,7 +253,7 @@ class RunJournal(BaseCallbackHandler):
         else:
             logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
         finally:
-            logger.info(f"Tool end for node {run_id}")
+            logger.debug("Tool end for node %s", run_id)

     # -- Internal methods --

@@ -313,8 +318,8 @@ class RunJournal(BaseCallbackHandler):
             if exc:
                 logger.warning("Journal flush task failed: %s", exc)

-    def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
-        _tags = tags or kwargs.get("tags", [])
+    def _identify_caller(self, tags: list[str] | None) -> str:
+        _tags = tags or []
         for tag in _tags:
             if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
                 return tag
@@ -371,9 +376,6 @@ class RunJournal(BaseCallbackHandler):
             "total_output_tokens": self._total_output_tokens,
             "total_tokens": self._total_tokens,
             "llm_call_count": self._llm_call_count,
-            "lead_agent_tokens": self._lead_agent_tokens,
-            "subagent_tokens": self._subagent_tokens,
-            "middleware_tokens": self._middleware_tokens,
             "message_count": self._msg_count,
             "last_ai_message": self._last_ai_msg,
             "first_human_message": self._first_human_msg,
diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py
index a70d02b9b..2188eeef0 100644
--- a/backend/tests/test_run_journal.py
+++ b/backend/tests/test_run_journal.py
@@ -381,3 +381,62 @@ class TestMiddlewareEvents:
         event_types = {e["event_type"] for e in events}
         assert "middleware:title" in event_types
         assert "middleware:guardrail" in event_types
+
+
+class TestChatModelStartHumanMessage:
+    """Tests for on_chat_model_start extracting the first human message."""
+
+    @pytest.mark.anyio
+    async def test_extracts_first_human_message(self, journal_setup):
+        """on_chat_model_start captures the first HumanMessage from prompts."""
+        from langchain_core.messages import AIMessage, HumanMessage
+
+        j, store = journal_setup
+        messages_batch = [
+            [HumanMessage(content="What is AI?"), AIMessage(content="Hi there")],
AI?"), AIMessage(content="Hi there")], + ] + j.on_chat_model_start({}, messages_batch, run_id=uuid4(), tags=["lead_agent"]) + await j.flush() + + assert j._first_human_msg == "What is AI?" + events = await store.list_events("t1", "r1") + human_events = [e for e in events if e["event_type"] == "llm.human.input"] + assert len(human_events) == 1 + assert human_events[0]["content"]["content"] == "What is AI?" + + @pytest.mark.anyio + async def test_skips_summary_named_human_messages(self, journal_setup): + """HumanMessages with name='summary' are skipped.""" + from langchain_core.messages import HumanMessage + + j, store = journal_setup + messages_batch = [ + [HumanMessage(content="Summarized context", name="summary"), HumanMessage(content="Real question")], + ] + j.on_chat_model_start({}, messages_batch, run_id=uuid4(), tags=["lead_agent"]) + await j.flush() + + assert j._first_human_msg == "Real question" + + @pytest.mark.anyio + async def test_only_first_human_message_captured(self, journal_setup): + """Subsequent on_chat_model_start calls do not overwrite the first message.""" + from langchain_core.messages import HumanMessage + + j, store = journal_setup + j.on_chat_model_start({}, [[HumanMessage(content="First question")]], run_id=uuid4(), tags=["lead_agent"]) + j.on_chat_model_start({}, [[HumanMessage(content="Second question")]], run_id=uuid4(), tags=["lead_agent"]) + await j.flush() + + assert j._first_human_msg == "First question" + events = await store.list_events("t1", "r1") + human_events = [e for e in events if e["event_type"] == "llm.human.input"] + assert len(human_events) == 1 + + @pytest.mark.anyio + async def test_empty_messages_no_crash(self, journal_setup): + """on_chat_model_start with empty messages does not crash.""" + j, store = journal_setup + j.on_chat_model_start({}, [], run_id=uuid4(), tags=["lead_agent"]) + await j.flush() + assert j._first_human_msg is None