diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index 465bbbff0..360d904bb 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -67,6 +67,9 @@ class RunJournal(BaseCallbackHandler): # Latency tracking self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time + # Tool call ID cache + self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id + # -- Lifecycle callbacks -- def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None: @@ -189,28 +192,47 @@ class RunJournal(BaseCallbackHandler): # -- Tool callbacks -- def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None: + tool_call_id = kwargs.get("tool_call_id") + if tool_call_id: + self._tool_call_ids[str(run_id)] = tool_call_id self._put( event_type="tool_start", category="trace", metadata={ "tool_name": serialized.get("name", ""), - "tool_call_id": kwargs.get("tool_call_id"), + "tool_call_id": tool_call_id, "args": str(input_str)[:2000], }, ) def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None: + tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None) + tool_name = kwargs.get("name", "") + + # Trace event (always) self._put( event_type="tool_end", category="trace", content=str(output), metadata={ - "tool_name": kwargs.get("name", ""), - "tool_call_id": kwargs.get("tool_call_id"), + "tool_name": tool_name, + "tool_call_id": tool_call_id, "status": "success", }, ) + # Message event: tool_result + self._put( + event_type="tool_result", + category="message", + content={ + "role": "tool", + "tool_call_id": tool_call_id or "", + "content": str(output), + }, + metadata={"tool_name": tool_name}, + ) + def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: self._put( event_type="tool_error", diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index f6d319279..3b9cdea78 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -185,9 +185,9 @@ class TestToolCallbacks: j.on_tool_end("results", run_id=uuid4(), name="web_search") await j.flush() events = await store.list_events("t1", "r1") - types = {e["event_type"] for e in events} - assert "tool_start" in types - assert "tool_end" in types + trace_types = {e["event_type"] for e in events if e["category"] == "trace"} + assert "tool_start" in trace_types + assert "tool_end" in trace_types @pytest.mark.anyio async def test_on_tool_error(self, journal_setup): @@ -623,3 +623,40 @@ class TestOpenAIMessageFormat: await j.flush() messages = await store.list_messages("t1") assert len(messages) == 0 + + +class TestToolResultMessage: + @pytest.mark.anyio + async def test_tool_end_produces_tool_result_message(self, journal_setup): + j, store = journal_setup + run_id = uuid4() + j.on_tool_start({"name": "web_search"}, '{"query": "test"}', run_id=run_id, tool_call_id="call_abc") + j.on_tool_end("search results here", run_id=run_id, name="web_search", tool_call_id="call_abc") + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 1 + assert messages[0]["event_type"] == "tool_result" + assert messages[0]["content"] == { + "role": "tool", + "tool_call_id": "call_abc", + "content": "search results here", + } + + @pytest.mark.anyio + async def test_tool_result_missing_tool_call_id(self, journal_setup): + j, store = journal_setup + run_id = uuid4() + j.on_tool_start({"name": "bash"}, "ls", run_id=run_id) + j.on_tool_end("file1.txt", run_id=run_id, name="bash") + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 1 + assert messages[0]["content"]["role"] == "tool" + + @pytest.mark.anyio + async def test_tool_error_no_tool_result_message(self, journal_setup): + j, store = journal_setup + j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1") + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 0