mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-13 12:13:46 +00:00
fix(runtime): persist run message summaries (#2850)
* fix(runtime): persist run message summaries (#2849) * fix(runtime): dedupe run message summaries
This commit is contained in:
parent
c3bc6c7cd5
commit
2eb11f97ab
@ -20,12 +20,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -71,6 +72,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
|
||||
self._counted_llm_run_ids: set[str] = set()
|
||||
self._counted_external_source_ids: set[str] = set()
|
||||
self._counted_message_llm_run_ids: set[str] = set()
|
||||
|
||||
# Convenience fields
|
||||
self._last_ai_msg: str | None = None
|
||||
@ -86,6 +88,50 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
@staticmethod
|
||||
def _message_text(message: BaseMessage) -> str:
|
||||
"""Extract displayable text from a message's mixed content shape."""
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, Mapping):
|
||||
text = block.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
else:
|
||||
nested = block.get("content")
|
||||
if isinstance(nested, str):
|
||||
parts.append(nested)
|
||||
return "".join(parts)
|
||||
if isinstance(content, Mapping):
|
||||
for key in ("text", "content"):
|
||||
value = content.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
text = getattr(message, "text", None)
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return ""
|
||||
|
||||
def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None:
|
||||
"""Update run-level convenience fields for persisted run rows."""
|
||||
self._msg_count += 1
|
||||
|
||||
# ``last_ai_message`` should represent the lead agent's user-facing
|
||||
# answer. Middleware/subagent model calls and empty tool-call-only
|
||||
# AI messages must not overwrite the last useful assistant text.
|
||||
is_ai_message = isinstance(message, AIMessage) or getattr(message, "type", None) == "ai"
|
||||
if is_ai_message and (caller is None or caller == "lead_agent"):
|
||||
text = self._message_text(message).strip()
|
||||
if text:
|
||||
self._last_ai_msg = text[:2000]
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
@ -164,6 +210,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
self._record_message_summary(m, caller=caller)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
@ -222,6 +269,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
if rid not in self._counted_message_llm_run_ids:
|
||||
self._record_message_summary(message, caller=caller)
|
||||
|
||||
# Token accumulation (dedup by langchain run_id to avoid double-counting
|
||||
# when the callback fires more than once for the same response)
|
||||
@ -245,6 +294,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
if messages:
|
||||
self._counted_message_llm_run_ids.add(str(run_id))
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
@ -260,12 +312,14 @@ class RunJournal(BaseCallbackHandler):
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
self._record_message_summary(msg)
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
self._record_message_summary(message)
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
|
||||
@ -339,6 +339,99 @@ class TestConvenienceFields:
|
||||
data = j.get_completion_data()
|
||||
assert data["first_human_message"] == "What is AI?"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_completion_data_counts_human_ai_and_tool_messages(self, journal_setup):
|
||||
from langchain_core.messages import HumanMessage, ToolMessage
|
||||
|
||||
j, _ = journal_setup
|
||||
j.on_chat_model_start({}, [[HumanMessage(content="Question")]], run_id=uuid4(), tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_tool_end(ToolMessage(content="Tool result", tool_call_id="call_1", name="search"), run_id=uuid4())
|
||||
|
||||
data = j.get_completion_data()
|
||||
|
||||
assert data["message_count"] == 3
|
||||
assert data["first_human_message"] == "Question"
|
||||
assert data["last_ai_message"] == "Answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_call_only_ai_does_not_clear_last_ai_message(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
j.on_llm_end(_make_llm_response("Useful answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_llm_end(
|
||||
_make_llm_response("", tool_calls=[{"id": "call_1", "name": "search", "args": {}}]),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
|
||||
data = j.get_completion_data()
|
||||
|
||||
assert data["message_count"] == 2
|
||||
assert data["last_ai_message"] == "Useful answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_last_ai_message_extracts_mixed_content_without_extra_newlines(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
j.on_llm_end(
|
||||
_make_llm_response(
|
||||
[
|
||||
{"type": "text", "text": "First "},
|
||||
{"type": "text", "content": "second"},
|
||||
" third",
|
||||
{"type": "image", "url": "ignored"},
|
||||
]
|
||||
),
|
||||
run_id=uuid4(),
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
|
||||
data = j.get_completion_data()
|
||||
|
||||
assert data["message_count"] == 1
|
||||
assert data["last_ai_message"] == "First second third"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_last_ai_message_extracts_mapping_content(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
j.on_llm_end(_make_llm_response({"content": "Nested answer"}), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
|
||||
data = j.get_completion_data()
|
||||
|
||||
assert data["message_count"] == 1
|
||||
assert data["last_ai_message"] == "Nested answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_duplicate_llm_run_id_does_not_double_count_message_summary(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
run_id = uuid4()
|
||||
|
||||
j.on_llm_end(_make_llm_response("Answer", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_llm_end(
|
||||
_make_llm_response("Answer", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
||||
run_id=run_id,
|
||||
parent_run_id=None,
|
||||
tags=["lead_agent"],
|
||||
)
|
||||
|
||||
data = j.get_completion_data()
|
||||
|
||||
assert data["message_count"] == 1
|
||||
assert data["last_ai_message"] == "Answer"
|
||||
assert data["total_tokens"] == 15
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_subagent_ai_does_not_overwrite_lead_last_ai_message(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
j.on_llm_end(_make_llm_response("Lead answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Subagent detail"), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
|
||||
|
||||
data = j.get_completion_data()
|
||||
|
||||
assert data["message_count"] == 2
|
||||
assert data["last_ai_message"] == "Lead answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_completion_data(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user