mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(middleware): repair dangling tool-call history after loop interru… (#2035)
* fix(middleware): repair dangling tool-call history after loop interruption (#2029) * docs(backend): fix middleware chain ordering --------- Co-authored-by: luoxiao6645 <luoxiao6645@gmail.com>
This commit is contained in:
parent
4efc8d404f
commit
5db71cb68c
@ -658,6 +658,8 @@ This is the difference between a chatbot with tool access and an agent with an a
|
|||||||
|
|
||||||
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
|
**Summarization**: Within a session, DeerFlow manages context aggressively — summarizing completed sub-tasks, offloading intermediate results to the filesystem, compressing what's no longer immediately relevant. This lets it stay sharp across long, multi-step tasks without blowing the context window.
|
||||||
|
|
||||||
|
**Strict Tool-Call Recovery**: When a provider or middleware interrupts a tool-call loop, DeerFlow now strips provider-level raw tool-call metadata on forced-stop assistant messages and injects placeholder tool results for dangling calls before the next model invocation. This keeps OpenAI-compatible reasoning models that strictly validate `tool_call_id` sequences from failing with malformed history errors.
|
||||||
|
|
||||||
### Long-Term Memory
|
### Long-Term Memory
|
||||||
|
|
||||||
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
|
Most agents forget everything the moment a conversation ends. DeerFlow remembers.
|
||||||
|
|||||||
@ -156,20 +156,26 @@ from deerflow.config import get_app_config
|
|||||||
|
|
||||||
### Middleware Chain
|
### Middleware Chain
|
||||||
|
|
||||||
Middlewares execute in strict order in `packages/harness/deerflow/agents/lead_agent/agent.py`:
|
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
|
||||||
|
|
||||||
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
|
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
|
||||||
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
|
||||||
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
|
||||||
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption)
|
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
|
||||||
5. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
5. **LLMErrorHandlingMiddleware** - Normalizes provider/model invocation failures into recoverable assistant-facing errors before later middleware/tool stages run
|
||||||
6. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
6. **GuardrailMiddleware** - Pre-tool-call authorization via pluggable `GuardrailProvider` protocol (optional, if `guardrails.enabled` in config). Evaluates each tool call and returns error ToolMessage on deny. Three provider options: built-in `AllowlistProvider` (zero deps), OAP policy providers (e.g. `aport-agent-guardrails`), or custom providers. See [docs/GUARDRAILS.md](docs/GUARDRAILS.md) for setup, usage, and how to implement a provider.
|
||||||
7. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
7. **SandboxAuditMiddleware** - Audits sandboxed shell/file operations for security logging before tool execution continues
|
||||||
8. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
|
||||||
9. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
|
||||||
10. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
|
||||||
11. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if subagent_enabled)
|
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
|
||||||
12. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
|
||||||
|
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
|
||||||
|
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
|
||||||
|
15. **DeferredToolFilterMiddleware** - Hides deferred tool schemas from the bound model until tool search is enabled (optional)
|
||||||
|
16. **SubagentLimitMiddleware** - Truncates excess `task` tool calls from model response to enforce `MAX_CONCURRENT_SUBAGENTS` limit (optional, if `subagent_enabled`)
|
||||||
|
17. **LoopDetectionMiddleware** - Detects repeated tool-call loops; hard-stop responses clear both structured `tool_calls` and raw provider tool-call metadata before forcing a final text answer
|
||||||
|
18. **ClarificationMiddleware** - Intercepts `ask_clarification` tool calls, interrupts via `Command(goto=END)` (must be last)
|
||||||
|
|
||||||
### Configuration System
|
### Configuration System
|
||||||
|
|
||||||
|
|||||||
@ -13,6 +13,7 @@ at the correct positions (immediately after each dangling AIMessage), not append
|
|||||||
to the end of the message list as before_model + add_messages reducer would do.
|
to the end of the message list as before_model + add_messages reducer would do.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import override
|
from typing import override
|
||||||
@ -33,6 +34,44 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
offending AIMessage so the LLM receives a well-formed conversation.
|
offending AIMessage so the LLM receives a well-formed conversation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _message_tool_calls(msg) -> list[dict]:
|
||||||
|
"""Return normalized tool calls from structured fields or raw provider payloads."""
|
||||||
|
tool_calls = getattr(msg, "tool_calls", None) or []
|
||||||
|
if tool_calls:
|
||||||
|
return list(tool_calls)
|
||||||
|
|
||||||
|
raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or []
|
||||||
|
normalized: list[dict] = []
|
||||||
|
for raw_tc in raw_tool_calls:
|
||||||
|
if not isinstance(raw_tc, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
function = raw_tc.get("function")
|
||||||
|
name = raw_tc.get("name")
|
||||||
|
if not name and isinstance(function, dict):
|
||||||
|
name = function.get("name")
|
||||||
|
|
||||||
|
args = raw_tc.get("args", {})
|
||||||
|
if not args and isinstance(function, dict):
|
||||||
|
raw_args = function.get("arguments")
|
||||||
|
if isinstance(raw_args, str):
|
||||||
|
try:
|
||||||
|
parsed_args = json.loads(raw_args)
|
||||||
|
except (TypeError, ValueError, json.JSONDecodeError):
|
||||||
|
parsed_args = {}
|
||||||
|
args = parsed_args if isinstance(parsed_args, dict) else {}
|
||||||
|
|
||||||
|
normalized.append(
|
||||||
|
{
|
||||||
|
"id": raw_tc.get("id"),
|
||||||
|
"name": name or "unknown",
|
||||||
|
"args": args if isinstance(args, dict) else {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return normalized
|
||||||
|
|
||||||
def _build_patched_messages(self, messages: list) -> list | None:
|
def _build_patched_messages(self, messages: list) -> list | None:
|
||||||
"""Return a new message list with patches inserted at the correct positions.
|
"""Return a new message list with patches inserted at the correct positions.
|
||||||
|
|
||||||
@ -51,7 +90,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
for msg in messages:
|
for msg in messages:
|
||||||
if getattr(msg, "type", None) != "ai":
|
if getattr(msg, "type", None) != "ai":
|
||||||
continue
|
continue
|
||||||
for tc in getattr(msg, "tool_calls", None) or []:
|
for tc in self._message_tool_calls(msg):
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||||
needs_patch = True
|
needs_patch = True
|
||||||
@ -70,7 +109,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
patched.append(msg)
|
patched.append(msg)
|
||||||
if getattr(msg, "type", None) != "ai":
|
if getattr(msg, "type", None) != "ai":
|
||||||
continue
|
continue
|
||||||
for tc in getattr(msg, "tool_calls", None) or []:
|
for tc in self._message_tool_calls(msg):
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||||
patched.append(
|
patched.append(
|
||||||
|
|||||||
@ -17,6 +17,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import OrderedDict, defaultdict
|
from collections import OrderedDict, defaultdict
|
||||||
|
from copy import deepcopy
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
from langchain.agents import AgentState
|
from langchain.agents import AgentState
|
||||||
@ -323,6 +324,26 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
# Fallback: coerce unexpected types to str to avoid TypeError
|
# Fallback: coerce unexpected types to str to avoid TypeError
|
||||||
return str(content) + f"\n\n{text}"
|
return str(content) + f"\n\n{text}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_hard_stop_update(last_msg, content: str | list) -> dict:
|
||||||
|
"""Clear tool-call metadata so forced-stop messages serialize as plain assistant text."""
|
||||||
|
update = {
|
||||||
|
"tool_calls": [],
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
additional_kwargs = dict(getattr(last_msg, "additional_kwargs", {}) or {})
|
||||||
|
for key in ("tool_calls", "function_call"):
|
||||||
|
additional_kwargs.pop(key, None)
|
||||||
|
update["additional_kwargs"] = additional_kwargs
|
||||||
|
|
||||||
|
response_metadata = deepcopy(getattr(last_msg, "response_metadata", {}) or {})
|
||||||
|
if response_metadata.get("finish_reason") == "tool_calls":
|
||||||
|
response_metadata["finish_reason"] = "stop"
|
||||||
|
update["response_metadata"] = response_metadata
|
||||||
|
|
||||||
|
return update
|
||||||
|
|
||||||
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||||
warning, hard_stop = self._track_and_check(state, runtime)
|
warning, hard_stop = self._track_and_check(state, runtime)
|
||||||
|
|
||||||
@ -330,12 +351,8 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
|
|||||||
# Strip tool_calls from the last AIMessage to force text output
|
# Strip tool_calls from the last AIMessage to force text output
|
||||||
messages = state.get("messages", [])
|
messages = state.get("messages", [])
|
||||||
last_msg = messages[-1]
|
last_msg = messages[-1]
|
||||||
stripped_msg = last_msg.model_copy(
|
content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG)
|
||||||
update={
|
stripped_msg = last_msg.model_copy(update=self._build_hard_stop_update(last_msg, content))
|
||||||
"tool_calls": [],
|
|
||||||
"content": self._append_text(last_msg.content, warning),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return {"messages": [stripped_msg]}
|
return {"messages": [stripped_msg]}
|
||||||
|
|
||||||
if warning:
|
if warning:
|
||||||
|
|||||||
@ -119,6 +119,31 @@ class TestBuildPatchedMessagesPatching:
|
|||||||
assert "interrupted" in tool_msg.content.lower()
|
assert "interrupted" in tool_msg.content.lower()
|
||||||
assert tool_msg.name == "bash"
|
assert tool_msg.name == "bash"
|
||||||
|
|
||||||
|
def test_raw_provider_tool_calls_are_patched(self):
|
||||||
|
mw = DanglingToolCallMiddleware()
|
||||||
|
msgs = [
|
||||||
|
AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[],
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_1",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "bash", "arguments": '{"command":"ls"}'},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
patched = mw._build_patched_messages(msgs)
|
||||||
|
assert patched is not None
|
||||||
|
assert len(patched) == 2
|
||||||
|
assert isinstance(patched[1], ToolMessage)
|
||||||
|
assert patched[1].tool_call_id == "call_1"
|
||||||
|
assert patched[1].name == "bash"
|
||||||
|
assert patched[1].status == "error"
|
||||||
|
|
||||||
|
|
||||||
class TestWrapModelCall:
|
class TestWrapModelCall:
|
||||||
def test_no_patch_passthrough(self):
|
def test_no_patch_passthrough(self):
|
||||||
|
|||||||
@ -413,6 +413,45 @@ class TestHardStopWithListContent:
|
|||||||
assert msg.content.startswith("thinking...")
|
assert msg.content.startswith("thinking...")
|
||||||
assert _HARD_STOP_MSG in msg.content
|
assert _HARD_STOP_MSG in msg.content
|
||||||
|
|
||||||
|
def test_hard_stop_clears_raw_tool_call_metadata(self):
|
||||||
|
"""Forced-stop messages must not retain provider-level raw tool-call payloads."""
|
||||||
|
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||||
|
runtime = _make_runtime()
|
||||||
|
call = [_bash_call("ls")]
|
||||||
|
|
||||||
|
def _make_provider_state():
|
||||||
|
return {
|
||||||
|
"messages": [
|
||||||
|
AIMessage(
|
||||||
|
content="thinking...",
|
||||||
|
tool_calls=call,
|
||||||
|
additional_kwargs={
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": "call_ls",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "bash", "arguments": '{"command":"ls"}'},
|
||||||
|
"thought_signature": "sig-1",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"function_call": {"name": "bash", "arguments": '{"command":"ls"}'},
|
||||||
|
},
|
||||||
|
response_metadata={"finish_reason": "tool_calls"},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
mw._apply(_make_provider_state(), runtime)
|
||||||
|
|
||||||
|
result = mw._apply(_make_provider_state(), runtime)
|
||||||
|
assert result is not None
|
||||||
|
msg = result["messages"][0]
|
||||||
|
assert msg.tool_calls == []
|
||||||
|
assert "tool_calls" not in msg.additional_kwargs
|
||||||
|
assert "function_call" not in msg.additional_kwargs
|
||||||
|
assert msg.response_metadata["finish_reason"] == "stop"
|
||||||
|
|
||||||
|
|
||||||
class TestToolFrequencyDetection:
|
class TestToolFrequencyDetection:
|
||||||
"""Tests for per-tool-type frequency detection (Layer 2).
|
"""Tests for per-tool-type frequency detection (Layer 2).
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user