diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py
index c6704c02d..634b8b9d1 100644
--- a/backend/app/gateway/services.py
+++ b/backend/app/gateway/services.py
@@ -98,6 +98,44 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
 _DEFAULT_ASSISTANT_ID = "lead_agent"
 
 
+# Whitelist of run-context keys that the langgraph-compat layer forwards from
+# ``body.context`` into the run config. ``config["context"]`` exists in
+# LangGraph >=0.6, but these values must be written to both ``configurable``
+# (for legacy ``_get_runtime_config`` consumers) and ``context`` because
+# LangGraph >=1.1.9 no longer makes ``ToolRuntime.context`` fall back to
+# ``configurable`` for consumers like ``setup_agent``.
+_CONTEXT_CONFIGURABLE_KEYS: frozenset[str] = frozenset(
+    {
+        "model_name",
+        "mode",
+        "thinking_enabled",
+        "reasoning_effort",
+        "is_plan_mode",
+        "subagent_enabled",
+        "max_concurrent_subagents",
+        "agent_name",
+        "is_bootstrap",
+    }
+)
+
+
+def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, Any] | None) -> None:
+    """Merge whitelisted ``body.context`` keys into ``config['configurable']`` (legacy
+    readers) and ``config['context']`` (``ToolRuntime.context``, e.g. ``setup_agent`` —
+    issue #2677). Existing values win. A missing or ``None`` section is created; any
+    other non-dict section is left untouched."""
+    if not context:
+        return
+    overrides = {key: context[key] for key in _CONTEXT_CONFIGURABLE_KEYS if key in context}
+    for section_name in ("configurable", "context"):
+        section = config.get(section_name)
+        if section is None:
+            section = config[section_name] = {}
+        if isinstance(section, dict):
+            for key, value in overrides.items():
+                section.setdefault(key, value)
+
+
 def resolve_agent_factory(assistant_id: str | None):
     """Resolve the agent factory callable from config.
 
@@ -245,27 +283,11 @@ async def start_run(
     graph_input = normalize_input(body.input)
     config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
 
-    # Merge DeerFlow-specific context overrides into configurable.
+    # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
     # The ``context`` field is a custom extension for the langgraph-compat layer
     # that carries agent configuration (model_name, thinking_enabled, etc.).
     # Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
-    context = getattr(body, "context", None)
-    if context:
-        _CONTEXT_CONFIGURABLE_KEYS = {
-            "model_name",
-            "mode",
-            "thinking_enabled",
-            "reasoning_effort",
-            "is_plan_mode",
-            "subagent_enabled",
-            "max_concurrent_subagents",
-            "agent_name",
-            "is_bootstrap",
-        }
-        configurable = config.setdefault("configurable", {})
-        for key in _CONTEXT_CONFIGURABLE_KEYS:
-            if key in context:
-                configurable.setdefault(key, context[key])
+    merge_run_context_overrides(config, getattr(body, "context", None))
 
     stream_modes = normalize_stream_modes(body.stream_mode)
 
diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py
index da05fa675..1223c2127 100644
--- a/backend/packages/harness/deerflow/runtime/runs/worker.py
+++ b/backend/packages/harness/deerflow/runtime/runs/worker.py
@@ -39,6 +39,24 @@ logger = logging.getLogger(__name__)
 _VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
 
 
+def _build_runtime_context(thread_id: str, run_id: str, caller_context: Any | None) -> dict[str, Any]:
+    """Build the dict that becomes ``ToolRuntime.context`` for the run.
+
+    Always includes ``thread_id`` and ``run_id``. Additional keys from the caller's
+    ``config['context']`` (e.g. ``agent_name`` for the bootstrap flow — issue #2677)
+    are merged in but never override ``thread_id``/``run_id``.
+
+    LangGraph 1.1+ surfaces this as ``runtime.context`` via the parent runtime stored
+    under ``config['configurable']['__pregel_runtime']`` — see
+    ``langgraph.pregel.main`` where ``parent_runtime.merge(...)`` is invoked.
+    """
+    runtime_ctx: dict[str, Any] = {"thread_id": thread_id, "run_id": run_id}
+    if isinstance(caller_context, dict):
+        for key, value in caller_context.items():
+            runtime_ctx.setdefault(key, value)
+    return runtime_ctx
+
+
 @dataclass(frozen=True)
 class RunContext:
     """Infrastructure dependencies for a single agent run.
@@ -169,15 +187,15 @@ async def run_agent(
         from langchain_core.runnables import RunnableConfig
         from langgraph.runtime import Runtime
 
-        # Inject runtime context so middlewares can access thread_id
-        # (langgraph-cli does this automatically; we must do it manually)
-        runtime = Runtime(context={"thread_id": thread_id, "run_id": run_id}, store=store)
-        # If the caller already set a ``context`` key (LangGraph >= 0.6.0
-        # prefers it over ``configurable`` for thread-level data), make
-        # sure ``thread_id`` is available there too.
+        # Inject runtime context so middlewares and tools (via ToolRuntime.context) can
+        # access thread-level data. langgraph-cli does this automatically; we must do it
+        # manually here because we drive the graph through ``agent.astream(config=...)``
+        # without passing the official ``context=`` parameter.
+        runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"))
         if "context" in config and isinstance(config["context"], dict):
             config["context"].setdefault("thread_id", thread_id)
             config["context"].setdefault("run_id", run_id)
+        runtime = Runtime(context=runtime_ctx, store=store)
         config.setdefault("configurable", {})["__pregel_runtime"] = runtime
 
         # Inject RunJournal as a LangChain callback handler.
diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py
index e0fcda294..013991b82 100644
--- a/backend/tests/test_gateway_services.py
+++ b/backend/tests/test_gateway_services.py
@@ -256,6 +256,37 @@ def test_context_merges_into_configurable():
     assert "thread_id" not in {k for k in context if k in _CONTEXT_CONFIGURABLE_KEYS}
 
 
+def test_merge_run_context_overrides_propagates_to_runtime_context():
+    """Regression for issue #2677: ``agent_name`` (and other whitelisted keys) from
+    ``body.context`` must be propagated into BOTH ``config['configurable']`` and
+    ``config['context']``. Previously only ``configurable`` was populated, so after
+    the LangGraph 1.1.x upgrade removed the fallback from ``configurable``, the
+    ``setup_agent`` tool read ``runtime.context`` with ``agent_name=None`` and
+    silently wrote SOUL.md to the global base_dir.
+    """
+    from app.gateway.services import build_run_config, merge_run_context_overrides
+
+    config = build_run_config("thread-1", None, None)
+    merge_run_context_overrides(config, {"agent_name": "my-agent", "is_bootstrap": True, "thread_id": "ignored"})
+
+    assert config["configurable"]["agent_name"] == "my-agent"
+    assert config["configurable"]["is_bootstrap"] is True
+    assert config["context"]["agent_name"] == "my-agent"
+    assert config["context"]["is_bootstrap"] is True
+    # Non-whitelisted keys (here ``thread_id``) are not forwarded into ``context``.
+    assert "thread_id" not in config["context"]
+
+
+def test_merge_run_context_overrides_noop_for_empty_context():
+    from app.gateway.services import build_run_config, merge_run_context_overrides
+
+    config = build_run_config("thread-1", None, None)
+    before = {k: dict(v) if isinstance(v, dict) else v for k, v in config.items()}
+    merge_run_context_overrides(config, None)
+    merge_run_context_overrides(config, {})
+    assert config == before
+
+
 def test_context_does_not_override_existing_configurable():
     """Values already in config.configurable must NOT be overridden by context."""
     from app.gateway.services import build_run_config
diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py
index f960f800f..b2b8da77f 100644
--- a/backend/tests/test_run_worker_rollback.py
+++ b/backend/tests/test_run_worker_rollback.py
@@ -2,7 +2,7 @@ from unittest.mock import AsyncMock, call
 
 import pytest
 
-from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _rollback_to_pre_run_checkpoint
+from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _build_runtime_context, _rollback_to_pre_run_checkpoint
 
 
 class FakeCheckpointer:
@@ -221,6 +221,43 @@ def test_agent_factory_supports_app_config_detects_supported_signature():
     assert _agent_factory_supports_app_config(factory) is True
 
 
+def test_build_runtime_context_defaults_to_thread_and_run_id():
+    ctx = _build_runtime_context("thread-1", "run-1", None)
+    assert ctx == {"thread_id": "thread-1", "run_id": "run-1"}
+
+
+def test_build_runtime_context_merges_caller_context():
+    """Regression for issue #2677: keys from ``config['context']`` (e.g. ``agent_name``)
+    must be merged into the Runtime's context so that ``ToolRuntime.context`` — which
+    is what ``setup_agent`` reads — can see them."""
+    caller_context = {"agent_name": "my-agent", "is_bootstrap": True, "model_name": "gpt-4"}
+
+    ctx = _build_runtime_context("thread-1", "run-1", caller_context)
+
+    assert ctx["thread_id"] == "thread-1"
+    assert ctx["run_id"] == "run-1"
+    assert ctx["agent_name"] == "my-agent"
+    assert ctx["is_bootstrap"] is True
+    assert ctx["model_name"] == "gpt-4"
+
+
+def test_build_runtime_context_caller_cannot_override_thread_id_or_run_id():
+    """A malicious or buggy caller must not be able to overwrite the worker-assigned
+    ``thread_id`` / ``run_id`` by stuffing them into ``config['context']``."""
+    caller_context = {"thread_id": "spoofed", "run_id": "spoofed", "agent_name": "ok"}
+
+    ctx = _build_runtime_context("real-thread", "real-run", caller_context)
+
+    assert ctx["thread_id"] == "real-thread"
+    assert ctx["run_id"] == "real-run"
+    assert ctx["agent_name"] == "ok"
+
+
+def test_build_runtime_context_ignores_non_dict_caller_context():
+    ctx = _build_runtime_context("thread-1", "run-1", "not-a-dict")
+    assert ctx == {"thread_id": "thread-1", "run_id": "run-1"}
+
+
 def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
     class BrokenCallable:
         def __call__(self, **kwargs):