mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
fix(todo-middleware): prevent premature agent exit with incomplete todos (#2135)
* fix(todo-middleware): prevent premature agent exit with incomplete todos When plan mode is active (is_plan_mode=True), the agent occasionally exits the loop and outputs a final response while todo items are still incomplete. This happens because the routing edge only checks for tool_calls, not todo completion state. Fixes #2112 Add an after_model override to TodoMiddleware with @hook_config(can_jump_to=["model"]). When the model produces a response with no tool calls but there are still incomplete todos, the middleware injects a todo_completion_reminder HumanMessage and returns jump_to=model to force another model turn. A cap of 2 reminders prevents infinite loops when the agent cannot make further progress. Also adds _completion_reminder_count() helper and 14 new unit tests covering all edge cases of the new after_model / aafter_model logic. * Remove unnecessary blank line in test file * Fix runtime argument annotation in before_model * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: octo-patch <octo-patch@github.com> Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
07fc25d285
commit
e4f896e90d
@ -1,9 +1,14 @@
|
||||
"""Middleware that extends TodoListMiddleware with context-loss detection.
|
||||
"""Middleware that extends TodoListMiddleware with context-loss detection and premature-exit prevention.
|
||||
|
||||
When the message history is truncated (e.g., by SummarizationMiddleware), the
|
||||
original `write_todos` tool call and its ToolMessage can be scrolled out of the
|
||||
active context window. This middleware detects that situation and injects a
|
||||
reminder message so the model still knows about the outstanding todo list.
|
||||
|
||||
Additionally, this middleware prevents the agent from exiting the loop while
|
||||
there are still incomplete todo items. When the model produces a final response
|
||||
(no tool calls) but todos are not yet complete, the middleware injects a reminder
|
||||
and jumps back to the model node to force continued engagement.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@ -12,6 +17,7 @@ from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain.agents.middleware.types import hook_config
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
@ -34,6 +40,11 @@ def _reminder_in_messages(messages: list[Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _completion_reminder_count(messages: list[Any]) -> int:
|
||||
"""Return the number of todo_completion_reminder HumanMessages in *messages*."""
|
||||
return sum(1 for msg in messages if isinstance(msg, HumanMessage) and getattr(msg, "name", None) == "todo_completion_reminder")
|
||||
|
||||
|
||||
def _format_todos(todos: list[Todo]) -> str:
|
||||
"""Format a list of Todo items into a human-readable string."""
|
||||
lines: list[str] = []
|
||||
@ -57,7 +68,7 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
def before_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime, # noqa: ARG002
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Inject a todo-list reminder when write_todos has left the context window."""
|
||||
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
||||
@ -98,3 +109,71 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of before_model."""
|
||||
return self.before_model(state, runtime)
|
||||
|
||||
# Maximum number of completion reminders before allowing the agent to exit.
|
||||
# This prevents infinite loops when the agent cannot make further progress.
|
||||
_MAX_COMPLETION_REMINDERS = 2
|
||||
|
||||
@hook_config(can_jump_to=["model"])
|
||||
@override
|
||||
def after_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Prevent premature agent exit when todo items are still incomplete.
|
||||
|
||||
In addition to the base class check for parallel ``write_todos`` calls,
|
||||
this override intercepts model responses that have no tool calls while
|
||||
there are still incomplete todo items. It injects a reminder
|
||||
``HumanMessage`` and jumps back to the model node so the agent
|
||||
continues working through the todo list.
|
||||
|
||||
A retry cap of ``_MAX_COMPLETION_REMINDERS`` (default 2) prevents
|
||||
infinite loops when the agent cannot make further progress.
|
||||
"""
|
||||
# 1. Preserve base class logic (parallel write_todos detection).
|
||||
base_result = super().after_model(state, runtime)
|
||||
if base_result is not None:
|
||||
return base_result
|
||||
|
||||
# 2. Only intervene when the agent wants to exit (no tool calls).
|
||||
messages = state.get("messages") or []
|
||||
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
||||
if not last_ai or last_ai.tool_calls:
|
||||
return None
|
||||
|
||||
# 3. Allow exit when all todos are completed or there are no todos.
|
||||
todos: list[Todo] = state.get("todos") or [] # type: ignore[assignment]
|
||||
if not todos or all(t.get("status") == "completed" for t in todos):
|
||||
return None
|
||||
|
||||
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
||||
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
|
||||
return None
|
||||
|
||||
# 5. Inject a reminder and force the agent back to the model.
|
||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||
reminder = HumanMessage(
|
||||
name="todo_completion_reminder",
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||
f"{incomplete_text}\n\n"
|
||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||
"as you finish them, and only respond when all items are done.\n"
|
||||
"</system_reminder>"
|
||||
),
|
||||
)
|
||||
return {"jump_to": "model", "messages": [reminder]}
|
||||
|
||||
@override
|
||||
@hook_config(can_jump_to=["model"])
|
||||
async def aafter_model(
|
||||
self,
|
||||
state: PlanningState,
|
||||
runtime: Runtime,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of after_model."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@ -7,6 +7,7 @@ from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares.todo_middleware import (
|
||||
TodoMiddleware,
|
||||
_completion_reminder_count,
|
||||
_format_todos,
|
||||
_reminder_in_messages,
|
||||
_todos_in_messages,
|
||||
@ -154,3 +155,148 @@ class TestAbeforeModel:
|
||||
result = asyncio.run(mw.abefore_model(state, _make_runtime()))
|
||||
assert result is not None
|
||||
assert result["messages"][0].name == "todo_reminder"
|
||||
|
||||
|
||||
def _completion_reminder_msg():
|
||||
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
|
||||
|
||||
|
||||
def _ai_no_tool_calls():
|
||||
return AIMessage(content="I'm done!")
|
||||
|
||||
|
||||
def _incomplete_todos():
|
||||
return [
|
||||
{"status": "completed", "content": "Step 1"},
|
||||
{"status": "in_progress", "content": "Step 2"},
|
||||
{"status": "pending", "content": "Step 3"},
|
||||
]
|
||||
|
||||
|
||||
def _all_completed_todos():
|
||||
return [
|
||||
{"status": "completed", "content": "Step 1"},
|
||||
{"status": "completed", "content": "Step 2"},
|
||||
]
|
||||
|
||||
|
||||
class TestCompletionReminderCount:
|
||||
def test_zero_when_no_reminders(self):
|
||||
msgs = [HumanMessage(content="hi"), _ai_no_tool_calls()]
|
||||
assert _completion_reminder_count(msgs) == 0
|
||||
|
||||
def test_counts_completion_reminders(self):
|
||||
msgs = [_completion_reminder_msg(), _completion_reminder_msg()]
|
||||
assert _completion_reminder_count(msgs) == 2
|
||||
|
||||
def test_does_not_count_todo_reminders(self):
|
||||
msgs = [_reminder_msg(), _completion_reminder_msg()]
|
||||
assert _completion_reminder_count(msgs) == 1
|
||||
|
||||
|
||||
class TestAfterModel:
|
||||
def test_returns_none_when_agent_still_using_tools(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_write_todos()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_returns_none_when_no_todos(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": [],
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_returns_none_when_todos_is_none(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": None,
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_returns_none_when_all_completed(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _all_completed_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_returns_none_when_no_messages(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert len(result["messages"]) == 1
|
||||
reminder = result["messages"][0]
|
||||
assert isinstance(reminder, HumanMessage)
|
||||
assert reminder.name == "todo_completion_reminder"
|
||||
assert "Step 2" in reminder.content
|
||||
assert "Step 3" in reminder.content
|
||||
|
||||
def test_reminder_lists_only_incomplete_items(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
content = result["messages"][0].content
|
||||
assert "Step 1" not in content # completed — should not appear
|
||||
assert "Step 2" in content
|
||||
assert "Step 3" in content
|
||||
|
||||
def test_allows_exit_after_max_reminders(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_completion_reminder_msg(),
|
||||
_completion_reminder_msg(),
|
||||
_ai_no_tool_calls(),
|
||||
],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_still_sends_reminder_before_cap(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_completion_reminder_msg(), # 1 reminder so far
|
||||
_ai_no_tool_calls(),
|
||||
],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
|
||||
|
||||
class TestAafterModel:
|
||||
def test_delegates_to_sync(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert result["messages"][0].name == "todo_completion_reminder"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user