From 20d2d2b3731edf9d5d72a191471c1fd856453350 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Tue, 12 May 2026 04:55:13 +0200 Subject: [PATCH] fix(middleware): Handle invalid tool calls in dangling pairing middleware (#2890) (#2891) --- .../dangling_tool_call_middleware.py | 83 +++++++++++++------ .../test_dangling_tool_call_middleware.py | 50 +++++++++++ 2 files changed, 107 insertions(+), 26 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py index 7bf600b9f..5bb54f3e5 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py @@ -36,42 +36,73 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): @staticmethod def _message_tool_calls(msg) -> list[dict]: - """Return normalized tool calls from structured fields or raw provider payloads.""" + """Return normalized tool calls from structured fields or raw provider payloads. + + LangChain stores malformed provider function calls in ``invalid_tool_calls``. + They do not execute, but provider adapters may still serialize enough of + the call id/name back into the next request that strict OpenAI-compatible + validators expect a matching ToolMessage. Treat them as dangling calls so + the next model request stays well-formed and the model sees a recoverable + tool error instead of another provider 400. + """ + normalized: list[dict] = [] + tool_calls = getattr(msg, "tool_calls", None) or [] - if tool_calls: - return list(tool_calls) + normalized.extend(list(tool_calls)) raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or [] - normalized: list[dict] = [] - for raw_tc in raw_tool_calls: - if not isinstance(raw_tc, dict): + if not tool_calls: + for raw_tc in raw_tool_calls: + if not isinstance(raw_tc, dict): + continue + + function = raw_tc.get("function") + name = raw_tc.get("name") + if not name and isinstance(function, dict): + name = function.get("name") + + args = raw_tc.get("args", {}) + if not args and isinstance(function, dict): + raw_args = function.get("arguments") + if isinstance(raw_args, str): + try: + parsed_args = json.loads(raw_args) + except (TypeError, ValueError, json.JSONDecodeError): + parsed_args = {} + args = parsed_args if isinstance(parsed_args, dict) else {} + + normalized.append( + { + "id": raw_tc.get("id"), + "name": name or "unknown", + "args": args if isinstance(args, dict) else {}, + } + ) + + for invalid_tc in getattr(msg, "invalid_tool_calls", None) or []: + if not isinstance(invalid_tc, dict): continue - - function = raw_tc.get("function") - name = raw_tc.get("name") - if not name and isinstance(function, dict): - name = function.get("name") - - args = raw_tc.get("args", {}) - if not args and isinstance(function, dict): - raw_args = function.get("arguments") - if isinstance(raw_args, str): - try: - parsed_args = json.loads(raw_args) - except (TypeError, ValueError, json.JSONDecodeError): - parsed_args = {} - args = parsed_args if isinstance(parsed_args, dict) else {} - normalized.append( { - "id": raw_tc.get("id"), - "name": name or "unknown", - "args": args if isinstance(args, dict) else {}, + "id": invalid_tc.get("id"), + "name": invalid_tc.get("name") or "unknown", + "args": {}, + "invalid": True, + "error": invalid_tc.get("error"), } ) return normalized + @staticmethod + def _synthetic_tool_message_content(tool_call: dict) -> str: + if tool_call.get("invalid"): + error = tool_call.get("error") + if isinstance(error, str) and error: + return f"[Tool call could not be executed because its arguments were invalid: {error}]" + return "[Tool call could not be executed because its arguments were invalid.]" + return "[Tool call was interrupted and did not return a result.]" + def _build_patched_messages(self, messages: list) -> list | None: """Return a new message list with patches inserted at the correct positions. @@ -114,7 +145,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: patched.append( ToolMessage( - content="[Tool call was interrupted and did not return a result.]", + content=self._synthetic_tool_message_content(tc), tool_call_id=tc_id, name=tc.get("name", "unknown"), status="error", diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index 90c162eac..b1d5c476a 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -14,6 +14,10 @@ def _ai_with_tool_calls(tool_calls): return AIMessage(content="", tool_calls=tool_calls) +def _ai_with_invalid_tool_calls(invalid_tool_calls): + return AIMessage(content="", tool_calls=[], invalid_tool_calls=invalid_tool_calls) + + def _tool_msg(tool_call_id, name="test_tool"): return ToolMessage(content="result", tool_call_id=tool_call_id, name=name) @@ -22,6 +26,16 @@ def _tc(name="bash", tc_id="call_1"): return {"name": name, "id": tc_id, "args": {}} +def _invalid_tc(name="write_file", tc_id="write_file:36", error="Failed to parse tool arguments: malformed JSON"): + return { + "type": "invalid_tool_call", + "name": name, + "id": tc_id, + "args": '{"description":"write report","path":"/mnt/user-data/outputs/report.md","content":"bad {"json"}"}', + "error": error, + } + + class TestBuildPatchedMessagesNoPatch: def test_empty_messages(self): mw = DanglingToolCallMiddleware() @@ -144,6 +158,42 @@ class TestBuildPatchedMessagesPatching: assert patched[1].name == "bash" assert patched[1].status == "error" + def test_invalid_tool_call_is_patched(self): + mw = DanglingToolCallMiddleware() + msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])] + patched = mw._build_patched_messages(msgs) + assert patched is not None + assert len(patched) == 2 + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "write_file:36" + assert patched[1].name == "write_file" + assert patched[1].status == "error" + assert "arguments were invalid" in patched[1].content + assert "Failed to parse tool arguments" in patched[1].content + + def test_valid_and_invalid_tool_calls_are_both_patched(self): + mw = DanglingToolCallMiddleware() + msgs = [ + AIMessage( + content="", + tool_calls=[_tc("bash", "call_1")], + invalid_tool_calls=[_invalid_tc()], + ) + ] + patched = mw._build_patched_messages(msgs) + assert patched is not None + tool_msgs = [m for m in patched if isinstance(m, ToolMessage)] + assert len(tool_msgs) == 2 + assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "write_file:36"} + + def test_invalid_tool_call_already_responded_is_not_patched(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_invalid_tool_calls([_invalid_tc()]), + _tool_msg("write_file:36", "write_file"), + ] + assert mw._build_patched_messages(msgs) is None + class TestWrapModelCall: def test_no_patch_passthrough(self):