"""Tests for the runtime-context helpers and checkpoint rollback in the runs worker."""

import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, call

import pytest
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver

from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus
from deerflow.runtime.runs.worker import (
    RunContext,
    _agent_factory_supports_app_config,
    _build_runtime_context,
    _install_runtime_context,
    _rollback_to_pre_run_checkpoint,
    run_agent,
)


class FakeCheckpointer:
    """Checkpointer double exposing only the async methods the rollback tests inspect."""

    def __init__(self, *, put_result):
        # ``aput`` resolves to ``put_result`` so each test controls the config
        # that the rollback code receives back from the checkpointer.
        self.adelete_thread = AsyncMock()
        self.aput = AsyncMock(return_value=put_result)
        self.aput_writes = AsyncMock()


def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
    """Return an empty checkpoint whose id and ``messages`` channel are filled in."""
    checkpoint = empty_checkpoint()
    checkpoint.update(
        id=checkpoint_id,
        channel_values={"messages": messages},
        channel_versions={"messages": version},
    )
    return checkpoint


def _thread_config(checkpoint_id=None):
    """Build the root-namespace config for ``thread-1``, optionally pinned to a checkpoint."""
    configurable = {"thread_id": "thread-1", "checkpoint_ns": ""}
    if checkpoint_id is not None:
        configurable["checkpoint_id"] = checkpoint_id
    return {"configurable": configurable}


def _minimal_snapshot(pending_writes, *, checkpoint_ns=""):
    """Build a pre-run snapshot of checkpoint ``ckpt-1`` with empty versions and metadata."""
    return {
        "checkpoint_ns": checkpoint_ns,
        "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
        "metadata": {},
        "pending_writes": pending_writes,
    }


async def _rollback(checkpointer, *, checkpoint_id, snapshot):
    """Invoke the rollback under test with the thread/run identifiers shared by every test."""
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id=checkpoint_id,
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )


def test_build_runtime_context_includes_app_config_when_present():
    """The app config passed as fourth argument is exposed under ``app_config``."""
    marker = object()

    built = _build_runtime_context("thread-1", "run-1", None, marker)

    assert built["thread_id"] == "thread-1"
    assert built["run_id"] == "run-1"
    assert built["app_config"] is marker


def test_install_runtime_context_preserves_existing_thread_id_and_threads_app_config():
    """An existing caller ``thread_id`` is kept while the missing keys are filled in."""
    marker = object()
    config = {"context": {"thread_id": "caller-thread"}}
    runtime_context = {
        "thread_id": "record-thread",
        "run_id": "run-1",
        "app_config": marker,
    }

    _install_runtime_context(config, runtime_context)

    installed = config["context"]
    assert installed["thread_id"] == "caller-thread"
    assert installed["run_id"] == "run-1"
    assert installed["app_config"] is marker


@pytest.mark.anyio
async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
    """A factory accepting only ``config`` still sees ``app_config`` through the context."""
    run_manager = RunManager()
    record = await run_manager.create("thread-1")
    bridge = SimpleNamespace(
        publish=AsyncMock(),
        publish_end=AsyncMock(),
        cleanup=AsyncMock(),
    )
    app_config = object()
    captured: dict[str, object] = {}

    class DummyAgent:
        """Agent stub that records the context handed to ``astream``."""

        async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
            captured["astream_context"] = config["context"]
            yield {"messages": []}

    def factory(*, config):
        # Deliberately has no ``app_config`` parameter: the context is the only channel.
        captured["factory_context"] = config["context"]
        return DummyAgent()

    await run_agent(
        bridge,
        run_manager,
        record,
        ctx=RunContext(checkpointer=None, app_config=app_config),
        agent_factory=factory,
        graph_input={},
        config={},
    )
    await asyncio.sleep(0)

    assert captured["factory_context"]["app_config"] is app_config
    assert captured["astream_context"]["app_config"] is app_config
    assert run_manager.get(record.run_id).status == RunStatus.success
    bridge.publish_end.assert_awaited_once_with(record.run_id)
    bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)


@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
    """With a snapshot, rollback re-puts the checkpoint and replays writes per task."""
    restored_config = _thread_config("restored-1")
    checkpointer = FakeCheckpointer(put_result=_thread_config("restored-1"))
    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        "metadata": {"source": "input"},
        "pending_writes": [
            ("task-a", "messages", {"content": "first"}),
            ("task-a", "status", "done"),
            ("task-b", "events", {"type": "tool"}),
        ],
    }

    await _rollback(checkpointer, checkpoint_id="ckpt-1", snapshot=snapshot)

    checkpointer.adelete_thread.assert_not_awaited()
    checkpointer.aput.assert_awaited_once()
    restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
    assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    assert restored_checkpoint["id"] != "ckpt-1"
    assert "channel_versions" in restored_checkpoint
    assert "channel_values" in restored_checkpoint
    assert restored_checkpoint["channel_versions"] == {"messages": 3}
    assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
    assert restored_metadata == {"source": "input"}
    assert new_versions == {"messages": 3}

    expected_write_calls = [
        call(
            restored_config,
            [("messages", {"content": "first"}), ("status", "done")],
            task_id="task-a",
        ),
        call(
            restored_config,
            [("events", {"type": "tool"})],
            task_id="task-b",
        ),
    ]
    assert checkpointer.aput_writes.await_args_list == expected_write_calls


@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
    """Against ``InMemorySaver`` the restored pre-run state becomes the latest checkpoint."""
    checkpointer = InMemorySaver()
    thread_config = _thread_config()

    before_checkpoint = _make_checkpoint("0001", ["before"], 1)
    before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})

    after_checkpoint = _make_checkpoint("0002", ["after"], 2)
    after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
    checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")

    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": before_checkpoint,
        "metadata": {"step": 1},
        "pending_writes": [("task-before", "messages", "pending-before")],
    }
    await _rollback(checkpointer, checkpoint_id="0001", snapshot=snapshot)

    latest = checkpointer.get_tuple(thread_config)
    assert latest is not None
    latest_id = latest.config["configurable"]["checkpoint_id"]
    assert latest_id != "0001"
    assert latest_id != "0002"
    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
    assert ("task-after", "messages", "pending-after") not in latest.pending_writes


@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
    """Without a pre-run checkpoint or snapshot the thread is deleted and nothing is written."""
    checkpointer = FakeCheckpointer(put_result=None)

    await _rollback(checkpointer, checkpoint_id=None, snapshot=None)

    checkpointer.adelete_thread.assert_awaited_once_with("thread-1")
    checkpointer.aput.assert_not_awaited()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
    """A config returned by ``aput`` that lacks ``checkpoint_id`` aborts before any write."""
    checkpointer = FakeCheckpointer(put_result=_thread_config())
    snapshot = _minimal_snapshot([("task-a", "messages", "value")])

    with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
        await _rollback(checkpointer, checkpoint_id="ckpt-1", snapshot=snapshot)

    checkpointer.adelete_thread.assert_not_awaited()
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
    """A snapshot namespace of ``None`` is passed to ``aput`` as the empty string."""
    checkpointer = FakeCheckpointer(put_result=_thread_config("restored-1"))
    snapshot = _minimal_snapshot([], checkpoint_ns=None)

    await _rollback(checkpointer, checkpoint_id="ckpt-1", snapshot=snapshot)

    checkpointer.aput.assert_awaited_once()
    restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
    assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    assert restored_checkpoint["id"] != "ckpt-1"
    assert restored_checkpoint["channel_versions"] == {}
    assert restored_metadata == {}
    assert new_versions == {}


@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
    """pending_writes containing a non-3-tuple item should raise RuntimeError."""
    checkpointer = FakeCheckpointer(put_result=_thread_config("restored-1"))
    snapshot = _minimal_snapshot(
        [
            ("task-a", "messages", "valid"),  # valid
            ["only", "two"],  # malformed: only 2 elements
        ]
    )

    with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
        await _rollback(checkpointer, checkpoint_id="ckpt-1", snapshot=snapshot)

    # aput succeeded but aput_writes should not be called due to malformed data
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
    """pending_writes containing a non-string channel should raise RuntimeError."""
    checkpointer = FakeCheckpointer(put_result=_thread_config("restored-1"))
    snapshot = _minimal_snapshot(
        [
            ("task-a", 123, "value"),  # malformed: channel is not a string
        ]
    )

    with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
        await _rollback(checkpointer, checkpoint_id="ckpt-1", snapshot=snapshot)

    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_propagates_aput_writes_failure():
    """If aput_writes fails, the exception should propagate (not be swallowed)."""
    checkpointer = FakeCheckpointer(put_result=_thread_config("restored-1"))
    # Simulate aput_writes failure
    checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost")
    snapshot = _minimal_snapshot([("task-a", "messages", "value")])

    with pytest.raises(RuntimeError, match="Database connection lost"):
        await _rollback(checkpointer, checkpoint_id="ckpt-1", snapshot=snapshot)

    # aput succeeded, aput_writes was called but failed
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_awaited_once()


def test_agent_factory_supports_app_config_detects_supported_signature():
    """A factory declaring an ``app_config`` keyword is reported as supporting it."""

    def factory(*, config, app_config=None):
        return (config, app_config)

    assert _agent_factory_supports_app_config(factory) is True


def test_build_runtime_context_defaults_to_thread_and_run_id():
    """With no caller context the result holds exactly the thread and run ids."""
    built = _build_runtime_context("thread-1", "run-1", None)

    assert built == {"thread_id": "thread-1", "run_id": "run-1"}


def test_build_runtime_context_merges_caller_context():
    """Regression for issue #2677: keys from ``config['context']`` (e.g.
    ``agent_name``) must be merged into the Runtime's context so that
    ``ToolRuntime.context`` — which is what ``setup_agent`` reads — can see them."""
    caller_context = {"agent_name": "my-agent", "is_bootstrap": True, "model_name": "gpt-4"}

    built = _build_runtime_context("thread-1", "run-1", caller_context)

    assert built["thread_id"] == "thread-1"
    assert built["run_id"] == "run-1"
    assert built["agent_name"] == "my-agent"
    assert built["is_bootstrap"] is True
    assert built["model_name"] == "gpt-4"


def test_build_runtime_context_caller_cannot_override_thread_id_or_run_id():
    """A malicious or buggy caller must not be able to overwrite the
    worker-assigned ``thread_id`` / ``run_id`` by stuffing them into
    ``config['context']``."""
    caller_context = {"thread_id": "spoofed", "run_id": "spoofed", "agent_name": "ok"}

    built = _build_runtime_context("real-thread", "real-run", caller_context)

    assert built["thread_id"] == "real-thread"
    assert built["run_id"] == "real-run"
    assert built["agent_name"] == "ok"


def test_build_runtime_context_ignores_non_dict_caller_context():
    """A caller context that is not a dict contributes no keys."""
    built = _build_runtime_context("thread-1", "run-1", "not-a-dict")

    assert built == {"thread_id": "thread-1", "run_id": "run-1"}


def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
    """A ``ValueError`` from ``inspect.signature`` is reported as "not supported"."""

    class BrokenCallable:
        def __call__(self, **kwargs):
            return kwargs

    def failing_signature(_obj):
        raise ValueError("boom")

    monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", failing_signature)

    assert _agent_factory_supports_app_config(BrokenCallable()) is False