mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-02 06:48:21 +00:00
* refactor: thread app config through lead prompt * fix: honor explicit app config across runtime paths * style: format subagent executor tests * fix: thread resolved app config and guard subagents-only fallback Address two PR review findings: 1. _create_summarization_middleware passed the original (possibly None) app_config into create_chat_model, forcing the model factory back to ambient get_app_config() and risking config drift between the middleware's resolved view and the model's view. Pass the resolved AppConfig instance through end-to-end. 2. get_available_subagent_names accepted Any-typed config and forwarded it to is_host_bash_allowed, which reads ``.sandbox``. A SubagentsAppConfig (also accepted upstream as a sum-type input) has no ``.sandbox`` attribute and would be silently treated as "no sandbox configured", incorrectly disabling the bash subagent. Guard on hasattr and fall back to ambient lookup otherwise. Adds regression tests for both paths. * chore: simplify hasattr guard and tighten regression tests - Collapse if/else into ternary in get_available_subagent_names; hasattr(None, ...) is False so the explicit None check was redundant. - Drop comments that narrate the change rather than explain non-obvious WHY (test names already convey intent). - Replace stringly-typed sentinel "no-arg" in regression test with direct args tuple comparison. --------- Co-authored-by: greatmengqi <chenmengqi.0376@bytedance.com>
340 lines
12 KiB
Python
340 lines
12 KiB
Python
import asyncio
|
|
from types import SimpleNamespace
|
|
from unittest.mock import AsyncMock, call
|
|
|
|
import pytest
|
|
|
|
from deerflow.runtime.runs.manager import RunManager
|
|
from deerflow.runtime.runs.schemas import RunStatus
|
|
from deerflow.runtime.runs.worker import RunContext, _agent_factory_supports_app_config, _build_runtime_context, _install_runtime_context, _rollback_to_pre_run_checkpoint, run_agent
|
|
|
|
|
|
class FakeCheckpointer:
    """Checkpointer test double whose async methods are ``AsyncMock`` objects.

    ``aput`` resolves to the ``put_result`` given at construction, which lets
    each test decide what config the code under test gets back from a put.
    ``aput_writes`` and ``adelete_thread`` are plain mocks used for call
    assertions.
    """

    def __init__(self, *, put_result):
        self.aput = AsyncMock(return_value=put_result)
        self.aput_writes = AsyncMock()
        self.adelete_thread = AsyncMock()
|
|
|
|
|
|
def test_build_runtime_context_includes_app_config_when_present():
    """An explicit app config is exposed under ``app_config`` alongside the ids."""
    sentinel_config = object()

    built = _build_runtime_context("thread-1", "run-1", None, sentinel_config)

    assert built["thread_id"] == "thread-1"
    assert built["run_id"] == "run-1"
    # Identity, not equality: the very same object must be threaded through.
    assert built["app_config"] is sentinel_config
|
|
|
|
|
|
def test_install_runtime_context_preserves_existing_thread_id_and_threads_app_config():
    """A caller-set ``thread_id`` survives installation; ``run_id`` and ``app_config`` are added."""
    sentinel_config = object()
    runnable_config = {"context": {"thread_id": "caller-thread"}}
    runtime_context = {
        "thread_id": "record-thread",
        "run_id": "run-1",
        "app_config": sentinel_config,
    }

    _install_runtime_context(runnable_config, runtime_context)

    installed = runnable_config["context"]
    # The value already present in the caller's context wins over the record's.
    assert installed["thread_id"] == "caller-thread"
    assert installed["run_id"] == "run-1"
    assert installed["app_config"] is sentinel_config
|
|
|
|
|
|
@pytest.mark.anyio
async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
    """A factory that only accepts ``config`` still receives the explicit app config.

    The app config carried by ``RunContext`` must show up in ``config["context"]``
    both when the factory is invoked and when the agent streams, and the run
    must end in ``RunStatus.success`` with the bridge closed and cleaned up.
    """
    manager = RunManager()
    run_record = await manager.create("thread-1")
    stream_bridge = SimpleNamespace(publish=AsyncMock(), publish_end=AsyncMock(), cleanup=AsyncMock())
    sentinel_config = object()
    seen: dict[str, object] = {}

    class DummyAgent:
        async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
            # Record the context visible at streaming time, then emit one chunk.
            seen["astream_context"] = config["context"]
            yield {"messages": []}

    def factory(*, config):
        # Deliberately has no ``app_config`` parameter: config-only signature.
        seen["factory_context"] = config["context"]
        return DummyAgent()

    run_context = RunContext(checkpointer=None, app_config=sentinel_config)
    await run_agent(
        stream_bridge,
        manager,
        run_record,
        ctx=run_context,
        agent_factory=factory,
        graph_input={},
        config={},
    )
    await asyncio.sleep(0)

    run_id = run_record.run_id
    assert seen["factory_context"]["app_config"] is sentinel_config
    assert seen["astream_context"]["app_config"] is sentinel_config
    assert manager.get(run_id).status == RunStatus.success
    stream_bridge.publish_end.assert_awaited_once_with(run_id)
    stream_bridge.cleanup.assert_awaited_once_with(run_id, delay=60)
|
|
|
|
|
|
@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
    """Rollback re-puts the pre-run checkpoint and replays its pending writes per task."""
    fake = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        "metadata": {"source": "input"},
        "pending_writes": [
            ("task-a", "messages", {"content": "first"}),
            ("task-a", "status", "done"),
            ("task-b", "events", {"type": "tool"}),
        ],
    }

    await _rollback_to_pre_run_checkpoint(
        checkpointer=fake,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    fake.adelete_thread.assert_not_awaited()
    # Expected values are spelled out as fresh literals so that any mutation of
    # the snapshot by the code under test cannot mask a mismatch.
    fake.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        {"source": "input"},
        {"messages": 3},
    )
    # One aput_writes call per task id, targeting the config returned by aput.
    expected_write_calls = [
        call(
            {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
            [("messages", {"content": "first"}), ("status", "done")],
            task_id="task-a",
        ),
        call(
            {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
            [("events", {"type": "tool"})],
            task_id="task-b",
        ),
    ]
    assert fake.aput_writes.await_args_list == expected_write_calls
|
|
|
|
|
|
@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
    """With no pre-run checkpoint or snapshot, rollback deletes the thread and restores nothing."""
    fake = FakeCheckpointer(put_result=None)
    rollback_kwargs = {
        "checkpointer": fake,
        "thread_id": "thread-1",
        "run_id": "run-1",
        "pre_run_checkpoint_id": None,
        "pre_run_snapshot": None,
        "snapshot_capture_failed": False,
    }

    await _rollback_to_pre_run_checkpoint(**rollback_kwargs)

    fake.adelete_thread.assert_awaited_once_with("thread-1")
    fake.aput.assert_not_awaited()
    fake.aput_writes.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.anyio
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
    """A restore config lacking ``checkpoint_id`` raises before any pending write is replayed."""
    # The config handed back by aput intentionally omits "checkpoint_id".
    fake = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}})
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [("task-a", "messages", "value")],
    }

    with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    fake.adelete_thread.assert_not_awaited()
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.anyio
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
    """A snapshot whose ``checkpoint_ns`` is ``None`` is restored under the ``""`` namespace."""
    fake = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    snapshot = {
        "checkpoint_ns": None,
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [],
    }

    await _rollback_to_pre_run_checkpoint(
        checkpointer=fake,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    # Positional args to aput: config, checkpoint, metadata, channel versions.
    fake.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {"id": "ckpt-1", "channel_versions": {}},
        {},
        {},
    )
|
|
|
|
|
|
@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
    """A pending_writes entry that is not a 3-tuple must make rollback raise RuntimeError."""
    fake = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    well_formed_write = ("task-a", "messages", "valid")
    two_element_write = ["only", "two"]  # malformed: only 2 elements
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [well_formed_write, two_element_write],
    }

    with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    # The checkpoint itself was put, but no writes were replayed: not even the
    # well-formed one that precedes the malformed entry.
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
    """A pending_writes entry whose channel is not a string must make rollback raise RuntimeError."""
    fake = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    integer_channel_write = ("task-a", 123, "value")  # malformed: channel is not a string
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [integer_channel_write],
    }

    with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.anyio
async def test_rollback_propagates_aput_writes_failure():
    """An exception raised by ``aput_writes`` must surface to the caller, not be swallowed."""
    fake = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    # Make the write-replay step blow up.
    fake.aput_writes.side_effect = RuntimeError("Database connection lost")
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": [("task-a", "messages", "value")],
    }

    with pytest.raises(RuntimeError, match="Database connection lost"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=fake,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot=snapshot,
            snapshot_capture_failed=False,
        )

    # Both steps were attempted exactly once; the second one is what failed.
    fake.aput.assert_awaited_once()
    fake.aput_writes.assert_awaited_once()
|
|
|
|
|
|
def test_agent_factory_supports_app_config_detects_supported_signature():
    """A factory declaring an ``app_config`` parameter is reported as supporting it."""

    def factory(*, config, app_config=None):
        return config, app_config

    supported = _agent_factory_supports_app_config(factory)

    assert supported is True
|
|
|
|
|
|
def test_build_runtime_context_defaults_to_thread_and_run_id():
    """With no caller context and no app config, only the two ids are present."""
    built = _build_runtime_context("thread-1", "run-1", None)

    assert built == {"thread_id": "thread-1", "run_id": "run-1"}
|
|
|
|
|
|
def test_build_runtime_context_merges_caller_context():
    """Regression for issue #2677: caller-supplied context keys must be merged in.

    Keys coming from ``config['context']`` (for example ``agent_name``) have to
    end up in the Runtime's context, because that is what ``ToolRuntime.context``
    exposes and what ``setup_agent`` reads.
    """
    supplied = {"agent_name": "my-agent", "is_bootstrap": True, "model_name": "gpt-4"}

    built = _build_runtime_context("thread-1", "run-1", supplied)

    # Worker-assigned ids are present...
    assert built["thread_id"] == "thread-1"
    assert built["run_id"] == "run-1"
    # ...and every caller key is carried over unchanged.
    assert built["agent_name"] == "my-agent"
    assert built["is_bootstrap"] is True
    assert built["model_name"] == "gpt-4"
|
|
|
|
|
|
def test_build_runtime_context_caller_cannot_override_thread_id_or_run_id():
    """Worker-assigned ``thread_id`` / ``run_id`` win over values in ``config['context']``.

    A malicious or buggy caller that stuffs its own ids into the context must
    not be able to replace the real ones; unrelated keys still pass through.
    """
    spoofing_context = {"thread_id": "spoofed", "run_id": "spoofed", "agent_name": "ok"}

    built = _build_runtime_context("real-thread", "real-run", spoofing_context)

    assert built["thread_id"] == "real-thread"
    assert built["run_id"] == "real-run"
    assert built["agent_name"] == "ok"
|
|
|
|
|
|
def test_build_runtime_context_ignores_non_dict_caller_context():
    """A caller context that is not a dict (here a string) contributes nothing."""
    built = _build_runtime_context("thread-1", "run-1", "not-a-dict")

    assert built == {"thread_id": "thread-1", "run_id": "run-1"}
|
|
|
|
|
|
def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
    """If ``inspect.signature`` raises ``ValueError``, support detection returns ``False``."""

    class BrokenCallable:
        def __call__(self, **kwargs):
            return kwargs

    def failing_signature(_obj):
        # Stand-in for inspect.signature that always fails.
        raise ValueError("boom")

    monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", failing_signature)

    assert _agent_factory_supports_app_config(BrokenCallable()) is False
|