From 5fd0e6ac894c4e23418abfba39591e874fd93f97 Mon Sep 17 00:00:00 2001 From: Eilen Shin <136898293+Eilen6316@users.noreply.github.com> Date: Fri, 8 May 2026 10:08:53 +0800 Subject: [PATCH] fix(middleware): sync raw tool call metadata (#2757) --- .../middlewares/subagent_limit_middleware.py | 3 +- .../middlewares/summarization_middleware.py | 7 +- .../agents/middlewares/tool_call_metadata.py | 50 +++++++++++ .../tests/test_subagent_limit_middleware.py | 25 ++++++ .../tests/test_summarization_middleware.py | 85 +++++++++++++++++++ 5 files changed, 165 insertions(+), 5 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/middlewares/tool_call_metadata.py diff --git a/backend/packages/harness/deerflow/agents/middlewares/subagent_limit_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/subagent_limit_middleware.py index 11de5131a..eaff3c181 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/subagent_limit_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/subagent_limit_middleware.py @@ -7,6 +7,7 @@ from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langgraph.runtime import Runtime +from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls from deerflow.subagents.executor import MAX_CONCURRENT_SUBAGENTS logger = logging.getLogger(__name__) @@ -63,7 +64,7 @@ class SubagentLimitMiddleware(AgentMiddleware[AgentState]): logger.warning(f"Truncated {dropped_count} excess task tool call(s) from model response (limit: {self.max_concurrent})") # Replace the AIMessage with truncated tool_calls (same id triggers replacement) - updated_msg = last_msg.model_copy(update={"tool_calls": truncated_tool_calls}) + updated_msg = clone_ai_message_with_tool_calls(last_msg, truncated_tool_calls) return {"messages": [updated_msg]} @override diff --git a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py index 9f2c1a055..65b98f9f5 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/summarization_middleware.py @@ -14,6 +14,8 @@ from langgraph.config import get_config from langgraph.graph.message import REMOVE_ALL_MESSAGES from langgraph.runtime import Runtime +from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls + logger = logging.getLogger(__name__) @@ -78,10 +80,7 @@ def _clone_ai_message( content: Any | None = None, ) -> AIMessage: """Clone an AIMessage while replacing its tool_calls list and optional content.""" - update: dict[str, Any] = {"tool_calls": tool_calls} - if content is not None: - update["content"] = content - return message.model_copy(update=update) + return clone_ai_message_with_tool_calls(message, tool_calls, content=content) @dataclass diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_call_metadata.py b/backend/packages/harness/deerflow/agents/middlewares/tool_call_metadata.py new file mode 100644 index 000000000..f0845622b --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/tool_call_metadata.py @@ -0,0 +1,50 @@ +"""Helpers for keeping AIMessage tool-call metadata consistent.""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.messages import AIMessage + + +def _raw_tool_call_id(raw_tool_call: Any) -> str | None: + if not isinstance(raw_tool_call, dict): + return None + + raw_id = raw_tool_call.get("id") + return raw_id if isinstance(raw_id, str) and raw_id else None + + +def clone_ai_message_with_tool_calls( + message: AIMessage, + tool_calls: list[dict[str, Any]], + *, + content: Any | None = None, +) -> AIMessage: + """Clone an AIMessage while keeping raw provider tool-call metadata in sync.""" + kept_ids = {tc["id"] for tc in tool_calls if isinstance(tc.get("id"), str) and tc["id"]} + + update: dict[str, Any] = {"tool_calls": tool_calls} + if content is not None: + update["content"] = content + + additional_kwargs = dict(getattr(message, "additional_kwargs", {}) or {}) + raw_tool_calls = additional_kwargs.get("tool_calls") + if isinstance(raw_tool_calls, list): + synced_raw_tool_calls = [raw_tc for raw_tc in raw_tool_calls if _raw_tool_call_id(raw_tc) in kept_ids] + if synced_raw_tool_calls: + additional_kwargs["tool_calls"] = synced_raw_tool_calls + else: + additional_kwargs.pop("tool_calls", None) + + if not tool_calls: + additional_kwargs.pop("function_call", None) + + update["additional_kwargs"] = additional_kwargs + + response_metadata = dict(getattr(message, "response_metadata", {}) or {}) + if not tool_calls and response_metadata.get("finish_reason") == "tool_calls": + response_metadata["finish_reason"] = "stop" + update["response_metadata"] = response_metadata + + return message.model_copy(update=update) diff --git a/backend/tests/test_subagent_limit_middleware.py b/backend/tests/test_subagent_limit_middleware.py index c331c3aca..969e53353 100644 --- a/backend/tests/test_subagent_limit_middleware.py +++ b/backend/tests/test_subagent_limit_middleware.py @@ -27,6 +27,14 @@ def _other_call(name="bash", call_id="call_other"): return {"name": name, "id": call_id, "args": {}} +def _raw_tool_call(call_id: str, name: str = "task") -> dict: + return { + "id": call_id, + "type": "function", + "function": {"name": name, "arguments": "{}"}, + } + + class TestClampSubagentLimit: def test_below_min_clamped_to_min(self): assert _clamp_subagent_limit(0) == MIN_SUBAGENT_LIMIT @@ -117,6 +125,23 @@ class TestTruncateTaskCalls: task_calls = [tc for tc in updated_msg.tool_calls if tc["name"] == "task"] assert len(task_calls) == 2 + def test_truncation_syncs_raw_provider_tool_calls(self): + mw = SubagentLimitMiddleware(max_concurrent=2) + msg = AIMessage( + content="", + tool_calls=[_task_call("t1"), _task_call("t2"), _task_call("t3"), _task_call("t4")], + additional_kwargs={"tool_calls": [_raw_tool_call("t1"), _raw_tool_call("t2"), _raw_tool_call("t3"), _raw_tool_call("t4")]}, + response_metadata={"finish_reason": "tool_calls"}, + ) + + result = mw._truncate_task_calls({"messages": [msg]}) + + assert result is not None + updated_msg = result["messages"][0] + assert [tc["id"] for tc in updated_msg.tool_calls] == ["t1", "t2"] + assert [tc["id"] for tc in updated_msg.additional_kwargs["tool_calls"]] == ["t1", "t2"] + assert updated_msg.response_metadata["finish_reason"] == "tool_calls" + def test_only_non_task_calls_returns_none(self): mw = SubagentLimitMiddleware() msg = AIMessage( diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py index 79ca8b01c..abed0105a 100644 --- a/backend/tests/test_summarization_middleware.py +++ b/backend/tests/test_summarization_middleware.py @@ -75,6 +75,14 @@ def _skill_conversation() -> list: ] +def _raw_tool_call(tool_id: str, name: str = "read_file") -> dict: + return { + "id": tool_id, + "type": "function", + "function": {"name": name, "arguments": "{}"}, + } + + def test_before_summarization_hook_receives_messages_before_compression() -> None: captured: list[SummarizationEvent] = [] middleware = _middleware(before_summarization=[captured.append]) @@ -413,6 +421,47 @@ def test_skill_rescue_does_not_preserve_non_skill_outputs_from_mixed_tool_calls( assert any(isinstance(m, ToolMessage) and m.content == "user notes" for m in summarized) +def test_skill_rescue_syncs_raw_provider_tool_calls_on_split_ai_messages() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + messages = [ + HumanMessage(content="u1"), + AIMessage( + content="reading skill and notes", + tool_calls=[ + _skill_read_call("skill-1", "alpha"), + {"name": "read_file", "id": "file-1", "args": {"path": "/mnt/user-data/workspace/notes.md"}}, + ], + additional_kwargs={"tool_calls": [_raw_tool_call("skill-1"), _raw_tool_call("file-1")]}, + ), + ToolMessage(content="alpha skill body", tool_call_id="skill-1"), + ToolMessage(content="user notes", tool_call_id="file-1"), + HumanMessage(content="u2"), + AIMessage(content="done"), + ] + + middleware.before_model({"messages": messages}, _runtime()) + + preserved = captured[0].preserved_messages + summarized = captured[0].messages_to_summarize + + preserved_ai = next(m for m in preserved if isinstance(m, AIMessage) and m.tool_calls) + summarized_ai = next(m for m in summarized if isinstance(m, AIMessage) and m.tool_calls) + + assert [tc["id"] for tc in preserved_ai.tool_calls] == ["skill-1"] + assert [tc["id"] for tc in preserved_ai.additional_kwargs["tool_calls"]] == ["skill-1"] + assert [tc["id"] for tc in summarized_ai.tool_calls] == ["file-1"] + assert [tc["id"] for tc in summarized_ai.additional_kwargs["tool_calls"]] == ["file-1"] + + def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None: captured: list[SummarizationEvent] = [] middleware = _middleware( @@ -451,6 +500,42 @@ def test_skill_rescue_clears_content_on_rescued_ai_clone() -> None: assert summarized_ai.content == "reading skill and notes" +def test_skill_rescue_removes_raw_provider_tool_calls_from_content_only_summary_clone() -> None: + captured: list[SummarizationEvent] = [] + middleware = _middleware( + before_summarization=[captured.append], + trigger=("messages", 4), + keep=("messages", 2), + preserve_recent_skill_count=5, + preserve_recent_skill_tokens=10_000, + preserve_recent_skill_tokens_per_skill=10_000, + ) + + messages = [ + HumanMessage(content="u1"), + AIMessage( + content="reading skill", + tool_calls=[_skill_read_call("skill-1", "alpha")], + additional_kwargs={"tool_calls": [_raw_tool_call("skill-1")], "function_call": {"name": "read_file"}}, + response_metadata={"finish_reason": "tool_calls"}, + ), + ToolMessage(content="alpha skill body", tool_call_id="skill-1"), + HumanMessage(content="u2"), + AIMessage(content="done"), + ] + + middleware.before_model({"messages": messages}, _runtime()) + + summarized = captured[0].messages_to_summarize + summarized_ai = next(m for m in summarized if isinstance(m, AIMessage)) + + assert summarized_ai.content == "reading skill" + assert summarized_ai.tool_calls == [] + assert "tool_calls" not in summarized_ai.additional_kwargs + assert "function_call" not in summarized_ai.additional_kwargs + assert summarized_ai.response_metadata["finish_reason"] == "stop" + + def test_skill_rescue_only_preserves_skill_calls_with_matched_tool_results() -> None: captured: list[SummarizationEvent] = [] middleware = _middleware(