mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-27 20:28:16 +00:00
feat(events): add tool_result message event with OpenAI tool message format
Cache tool_call_id from on_tool_start keyed by run_id as fallback for on_tool_end, then emit a tool_result message event (role=tool, tool_call_id, content) after each successful tool completion. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
8b1d569589
commit
704f6a9209
@ -67,6 +67,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
# Latency tracking
|
||||
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
||||
|
||||
# Tool call ID cache
|
||||
self._tool_call_ids: dict[str, str] = {} # langchain run_id -> tool_call_id
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
@ -189,28 +192,47 @@ class RunJournal(BaseCallbackHandler):
|
||||
# -- Tool callbacks --
|
||||
|
||||
def on_tool_start(self, serialized: dict, input_str: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
tool_call_id = kwargs.get("tool_call_id")
|
||||
if tool_call_id:
|
||||
self._tool_call_ids[str(run_id)] = tool_call_id
|
||||
self._put(
|
||||
event_type="tool_start",
|
||||
category="trace",
|
||||
metadata={
|
||||
"tool_name": serialized.get("name", ""),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"tool_call_id": tool_call_id,
|
||||
"args": str(input_str)[:2000],
|
||||
},
|
||||
)
|
||||
|
||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
|
||||
# Trace event (always)
|
||||
self._put(
|
||||
event_type="tool_end",
|
||||
category="trace",
|
||||
content=str(output),
|
||||
metadata={
|
||||
"tool_name": kwargs.get("name", ""),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
"status": "success",
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content={
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id or "",
|
||||
"content": str(output),
|
||||
},
|
||||
metadata={"tool_name": tool_name},
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(
|
||||
event_type="tool_error",
|
||||
|
||||
@ -185,9 +185,9 @@ class TestToolCallbacks:
|
||||
j.on_tool_end("results", run_id=uuid4(), name="web_search")
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
types = {e["event_type"] for e in events}
|
||||
assert "tool_start" in types
|
||||
assert "tool_end" in types
|
||||
trace_types = {e["event_type"] for e in events if e["category"] == "trace"}
|
||||
assert "tool_start" in trace_types
|
||||
assert "tool_end" in trace_types
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_tool_error(self, journal_setup):
|
||||
@ -623,3 +623,40 @@ class TestOpenAIMessageFormat:
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 0
|
||||
|
||||
|
||||
class TestToolResultMessage:
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_end_produces_tool_result_message(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_tool_start({"name": "web_search"}, '{"query": "test"}', run_id=run_id, tool_call_id="call_abc")
|
||||
j.on_tool_end("search results here", run_id=run_id, name="web_search", tool_call_id="call_abc")
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "tool_result"
|
||||
assert messages[0]["content"] == {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "search results here",
|
||||
}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_result_missing_tool_call_id(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_tool_start({"name": "bash"}, "ls", run_id=run_id)
|
||||
j.on_tool_end("file1.txt", run_id=run_id, name="bash")
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"]["role"] == "tool"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_error_no_tool_result_message(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1")
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 0
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user