From 9892a7d46823888351bef32beddae247ffb1d294 Mon Sep 17 00:00:00 2001 From: YuJitang Date: Sun, 10 May 2026 22:47:30 +0800 Subject: [PATCH 01/86] fix: bucket subagent token usage into parent run totals (#2838) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: bucket subagent token usage into RunRow.subagent_tokens Add caller-bucketed token tracking to RunJournal so subagent and middleware LLM calls are written to the correct RunRow columns instead of all falling into lead_agent_tokens (default 0). - RunJournal: accumulate _lead_agent_tokens / _subagent_tokens / _middleware_tokens in on_llm_end, deduped by langchain run_id. Add record_external_llm_usage_records() for external sources (respects track_token_usage flag). Return caller buckets from get_completion_data(). - SubagentTokenCollector: new lightweight callback handler that collects LLM usage within subagent execution. - SubagentExecutor: wire collector into subagent run_config and sync records to SubagentResult on every chunk (timeout/cancel safe). - SubagentResult: add token_usage_records and usage_reported fields. - task_tool: report subagent usage to parent RunJournal on every terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including the CancelledError path, guarded against double-reporting. No DB migration needed — RunRow columns already exist. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix: address token usage review feedback * Address review follow-ups --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../harness/deerflow/runtime/journal.py | 68 ++++- .../harness/deerflow/subagents/executor.py | 16 ++ .../deerflow/subagents/token_collector.py | 63 +++++ .../deerflow/tools/builtins/task_tool.py | 139 +++++++--- backend/tests/test_run_journal.py | 238 ++++++++++++++++++ .../tests/test_subagent_token_collector.py | 161 ++++++++++++ backend/tests/test_task_tool_core_logic.py | 229 +++++++++++++---- frontend/tests/unit/core/threads/api.test.ts | 6 +- 8 files changed, 843 insertions(+), 77 deletions(-) create mode 100644 backend/packages/harness/deerflow/subagents/token_collector.py create mode 100644 backend/tests/test_subagent_token_collector.py diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index a0c2d029b..41e48efed 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -63,6 +63,15 @@ class RunJournal(BaseCallbackHandler): self._total_tokens = 0 self._llm_call_count = 0 + # Caller-bucketed token accumulators + self._lead_agent_tokens = 0 + self._subagent_tokens = 0 + self._middleware_tokens = 0 + + # Dedup: LangChain may fire on_llm_end multiple times for the same run_id + self._counted_llm_run_ids: set[str] = set() + self._counted_external_source_ids: set[str] = set() + # Convenience fields self._last_ai_msg: str | None = None self._first_human_msg: str | None = None @@ -214,19 +223,28 @@ class RunJournal(BaseCallbackHandler): }, ) - # Token accumulation + # Token accumulation (dedup by langchain run_id to avoid double-counting + # when the callback fires more than once for the same response) if self._track_tokens: input_tk = usage_dict.get("input_tokens", 0) or 0 output_tk = usage_dict.get("output_tokens", 0) or 0 total_tk = usage_dict.get("total_tokens", 0) or 0 if total_tk == 0: total_tk = input_tk + output_tk - if total_tk > 0: + if total_tk > 0 and rid not in self._counted_llm_run_ids: + self._counted_llm_run_ids.add(rid) self._total_input_tokens += input_tk self._total_output_tokens += output_tk self._total_tokens += total_tk self._llm_call_count += 1 + if caller.startswith("subagent:"): + self._subagent_tokens += total_tk + elif caller.startswith("middleware:"): + self._middleware_tokens += total_tk + else: + self._lead_agent_tokens += total_tk + def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: self._llm_start_times.pop(str(run_id), None) self._put(event_type="llm.error", category="trace", content=str(error)) @@ -330,6 +348,49 @@ class RunJournal(BaseCallbackHandler): # -- Public methods (called by worker) -- + def record_external_llm_usage_records( + self, + records: list[dict[str, int | str]], + ) -> None: + """Record token usage from external sources (e.g., subagents). + + Each record should contain: + source_run_id: Unique identifier to prevent double-counting + caller: Caller tag (e.g. "subagent:general-purpose") + input_tokens: Input token count + output_tokens: Output token count + total_tokens: Total token count (computed from input+output if 0/missing) + """ + if not self._track_tokens: + return + for record in records: + source_id = str(record.get("source_run_id", "")) + if not source_id: + continue + if source_id in self._counted_external_source_ids: + continue + + total_tk = record.get("total_tokens", 0) or 0 + if total_tk <= 0: + input_tk = record.get("input_tokens", 0) or 0 + output_tk = record.get("output_tokens", 0) or 0 + total_tk = input_tk + output_tk + if total_tk <= 0: + continue + + self._counted_external_source_ids.add(source_id) + self._total_input_tokens += record.get("input_tokens", 0) or 0 + self._total_output_tokens += record.get("output_tokens", 0) or 0 + self._total_tokens += total_tk + + caller = str(record.get("caller", "")) + if caller.startswith("subagent:"): + self._subagent_tokens += total_tk + elif caller.startswith("middleware:"): + self._middleware_tokens += total_tk + else: + self._lead_agent_tokens += total_tk + def set_first_human_message(self, content: str) -> None: """Record the first human message for convenience fields.""" self._first_human_msg = content[:2000] if content else None @@ -376,6 +437,9 @@ class RunJournal(BaseCallbackHandler): "total_output_tokens": self._total_output_tokens, "total_tokens": self._total_tokens, "llm_call_count": self._llm_call_count, + "lead_agent_tokens": self._lead_agent_tokens, + "subagent_tokens": self._subagent_tokens, + "middleware_tokens": self._middleware_tokens, "message_count": self._msg_count, "last_ai_message": self._last_ai_msg, "first_human_message": self._first_human_msg, diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index 64ba4c2c5..a2fec6432 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -26,6 +26,7 @@ from deerflow.models import create_chat_model from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools from deerflow.skills.types import Skill from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name +from deerflow.subagents.token_collector import SubagentTokenCollector logger = logging.getLogger(__name__) @@ -70,6 +71,8 @@ class SubagentResult: started_at: datetime | None = None completed_at: datetime | None = None ai_messages: list[dict[str, Any]] | None = None + token_usage_records: list[dict[str, int | str]] = field(default_factory=list) + usage_reported: bool = False cancel_event: threading.Event = field(default_factory=threading.Event, repr=False) def __post_init__(self): @@ -412,13 +415,20 @@ class SubagentExecutor: ai_messages = [] result.ai_messages = ai_messages + collector: SubagentTokenCollector | None = None try: state, filtered_tools = await self._build_initial_state(task) agent = self._create_agent(filtered_tools) + # Token collector for subagent LLM calls + collector_caller = f"subagent:{self.config.name}" + collector = SubagentTokenCollector(caller=collector_caller) + # Build config with thread_id for sandbox access and recursion limit run_config: RunnableConfig = { "recursion_limit": self.config.max_turns, + "callbacks": [collector], + "tags": [collector_caller], } context: dict[str, Any] = {} if self.thread_id: @@ -441,6 +451,8 @@ class SubagentExecutor: result.status = SubagentStatus.CANCELLED result.error = "Cancelled by user" result.completed_at = datetime.now() + if collector is not None: + result.token_usage_records = collector.snapshot_records() return result async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type] @@ -455,6 +467,7 @@ class SubagentExecutor: result.status = SubagentStatus.CANCELLED result.error = "Cancelled by user" result.completed_at = datetime.now() + result.token_usage_records = collector.snapshot_records() return result final_state = chunk @@ -481,6 +494,7 @@ class SubagentExecutor: logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution") + result.token_usage_records = collector.snapshot_records() if final_state is None: logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state") @@ -560,6 +574,8 @@ class SubagentExecutor: result.status = SubagentStatus.FAILED result.error = str(e) result.completed_at = datetime.now() + if collector is not None: + result.token_usage_records = collector.snapshot_records() return result diff --git a/backend/packages/harness/deerflow/subagents/token_collector.py b/backend/packages/harness/deerflow/subagents/token_collector.py new file mode 100644 index 000000000..56b419f01 --- /dev/null +++ b/backend/packages/harness/deerflow/subagents/token_collector.py @@ -0,0 +1,63 @@ +"""Callback handler that collects LLM token usage within a subagent. + +Each subagent execution creates its own collector. After the subagent +finishes, the collected records are transferred to the parent RunJournal +via :meth:`RunJournal.record_external_llm_usage_records`. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.callbacks import BaseCallbackHandler + + +class SubagentTokenCollector(BaseCallbackHandler): + """Lightweight callback handler that collects LLM token usage within a subagent.""" + + def __init__(self, caller: str): + super().__init__() + self.caller = caller + self._records: list[dict[str, int | str]] = [] + self._counted_run_ids: set[str] = set() + + def on_llm_end( + self, + response: Any, + *, + run_id: Any, + tags: list[str] | None = None, + **kwargs: Any, + ) -> None: + rid = str(run_id) + if rid in self._counted_run_ids: + return + + for generation in response.generations: + for gen in generation: + if not hasattr(gen, "message"): + continue + usage = getattr(gen.message, "usage_metadata", None) + usage_dict = dict(usage) if usage else {} + input_tk = usage_dict.get("input_tokens", 0) or 0 + output_tk = usage_dict.get("output_tokens", 0) or 0 + total_tk = usage_dict.get("total_tokens", 0) or 0 + if total_tk <= 0: + total_tk = input_tk + output_tk + if total_tk <= 0: + continue + self._counted_run_ids.add(rid) + self._records.append( + { + "source_run_id": rid, + "caller": self.caller, + "input_tokens": input_tk, + "output_tokens": output_tk, + "total_tokens": total_tk, + } + ) + return + + def snapshot_records(self) -> list[dict[str, int | str]]: + """Return a copy of the accumulated usage records.""" + return list(self._records) diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index a124e00ba..861c45b45 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -27,6 +27,92 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _is_subagent_terminal(result: Any) -> bool: + """Return whether a background subagent result is safe to clean up.""" + return result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None + + +async def _await_subagent_terminal(task_id: str, max_polls: int) -> Any | None: + """Poll until the background subagent reaches a terminal status or we run out of polls.""" + for _ in range(max_polls): + result = get_background_task_result(task_id) + if result is None: + return None + if _is_subagent_terminal(result): + return result + await asyncio.sleep(5) + return None + + +async def _deferred_cleanup_subagent_task(task_id: str, trace_id: str, max_polls: int) -> None: + """Keep polling a cancelled subagent until it can be safely removed.""" + cleanup_poll_count = 0 + while True: + result = get_background_task_result(task_id) + if result is None: + return + if _is_subagent_terminal(result): + cleanup_background_task(task_id) + return + if cleanup_poll_count >= max_polls: + logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls") + return + await asyncio.sleep(5) + cleanup_poll_count += 1 + + +def _log_cleanup_failure(cleanup_task: asyncio.Task[None], *, trace_id: str, task_id: str) -> None: + if cleanup_task.cancelled(): + return + + exc = cleanup_task.exception() + if exc is not None: + logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}") + + +def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls: int) -> None: + logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}") + cleanup_task = asyncio.create_task(_deferred_cleanup_subagent_task(task_id, trace_id, max_polls)) + cleanup_task.add_done_callback(lambda task: _log_cleanup_failure(task, trace_id=trace_id, task_id=task_id)) + + +def _find_usage_recorder(runtime: Any) -> Any | None: + """Find a callback handler with ``record_external_llm_usage_records`` in the runtime config.""" + if runtime is None: + return None + config = getattr(runtime, "config", None) + if not isinstance(config, dict): + return None + callbacks = config.get("callbacks", []) + if not callbacks: + return None + for cb in callbacks: + if hasattr(cb, "record_external_llm_usage_records"): + return cb + return None + + +def _report_subagent_usage(runtime: Any, result: Any) -> None: + """Report subagent token usage to the parent RunJournal, if available. + + Each subagent task must be reported only once (guarded by usage_reported). + """ + if getattr(result, "usage_reported", True): + return + records = getattr(result, "token_usage_records", None) or [] + if not records: + return + journal = _find_usage_recorder(runtime) + if journal is None: + logger.debug("No usage recorder found in runtime callbacks — subagent token usage not recorded") + return + try: + journal.record_external_llm_usage_records(records) + result.usage_reported = True + except Exception: + logger.warning("Failed to report subagent token usage", exc_info=True) + + def _get_runtime_app_config(runtime: Any) -> "AppConfig | None": context = getattr(runtime, "context", None) if isinstance(context, dict): @@ -227,21 +313,25 @@ async def task_tool( # Check if task completed, failed, or timed out if result.status == SubagentStatus.COMPLETED: + _report_subagent_usage(runtime, result) writer({"type": "task_completed", "task_id": task_id, "result": result.result}) logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") cleanup_background_task(task_id) return f"Task Succeeded. Result: {result.result}" elif result.status == SubagentStatus.FAILED: + _report_subagent_usage(runtime, result) writer({"type": "task_failed", "task_id": task_id, "error": result.error}) logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") cleanup_background_task(task_id) return f"Task failed. Error: {result.error}" elif result.status == SubagentStatus.CANCELLED: + _report_subagent_usage(runtime, result) writer({"type": "task_cancelled", "task_id": task_id, "error": result.error}) logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}") cleanup_background_task(task_id) return "Task cancelled by user." elif result.status == SubagentStatus.TIMED_OUT: + _report_subagent_usage(runtime, result) writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") cleanup_background_task(task_id) @@ -260,43 +350,28 @@ async def task_tool( if poll_count > max_poll_count: timeout_minutes = config.timeout_seconds // 60 logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") + _report_subagent_usage(runtime, result) writer({"type": "task_timed_out", "task_id": task_id}) return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" except asyncio.CancelledError: # Signal the background subagent thread to stop cooperatively. - # Without this, the thread (running in ThreadPoolExecutor with its - # own event loop via asyncio.run) would continue executing even - # after the parent task is cancelled. request_cancel_background_task(task_id) - async def cleanup_when_done() -> None: - max_cleanup_polls = max_poll_count - cleanup_poll_count = 0 + # Wait (shielded) for the subagent to reach a terminal state so the + # final token usage snapshot is reported to the parent RunJournal + # before the parent worker persists get_completion_data(). + terminal_result = None + try: + terminal_result = await asyncio.shield(_await_subagent_terminal(task_id, max_poll_count)) + except asyncio.CancelledError: + pass - while True: - result = get_background_task_result(task_id) - if result is None: - return - - if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None: - cleanup_background_task(task_id) - return - - if cleanup_poll_count > max_cleanup_polls: - logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls") - return - - await asyncio.sleep(5) - cleanup_poll_count += 1 - - def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None: - if cleanup_task.cancelled(): - return - - exc = cleanup_task.exception() - if exc is not None: - logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}") - - logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}") - asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure) + # Report whatever the subagent collected (even if we timed out). + final_result = terminal_result or get_background_task_result(task_id) + if final_result is not None: + _report_subagent_usage(runtime, final_result) + if final_result is not None and _is_subagent_terminal(final_result): + cleanup_background_task(task_id) + else: + _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) raise diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index 2188eeef0..27c05619c 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -383,6 +383,244 @@ class TestMiddlewareEvents: assert "middleware:guardrail" in event_types +class TestCallerBucketing: + """Tests for caller-bucketed token accumulation (lead_agent / subagent / middleware).""" + + def test_lead_agent_bucketing(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + assert j._lead_agent_tokens == 15 + assert j._subagent_tokens == 0 + assert j._middleware_tokens == 0 + + def test_subagent_bucketing(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} + j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"]) + assert j._subagent_tokens == 30 + assert j._lead_agent_tokens == 0 + assert j._middleware_tokens == 0 + + def test_middleware_bucketing(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7} + j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarize"]) + assert j._middleware_tokens == 7 + assert j._lead_agent_tokens == 0 + assert j._subagent_tokens == 0 + + def test_mixed_callers_sum_independently(self, journal_setup): + j, _ = journal_setup + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:bash"]) + j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:title"]) + assert j._lead_agent_tokens == 15 + assert j._subagent_tokens == 15 + assert j._middleware_tokens == 15 + assert j._total_tokens == 45 + + def test_get_completion_data_includes_buckets(self, journal_setup): + j, _ = journal_setup + j._lead_agent_tokens = 100 + j._subagent_tokens = 200 + j._middleware_tokens = 50 + data = j.get_completion_data() + assert data["lead_agent_tokens"] == 100 + assert data["subagent_tokens"] == 200 + assert data["middleware_tokens"] == 50 + + def test_dedup_same_run_id(self, journal_setup): + """Same langchain run_id in on_llm_end must not double-count.""" + j, _ = journal_setup + run_id = uuid4() + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + assert j._total_tokens == 15 + assert j._lead_agent_tokens == 15 + assert j._llm_call_count == 1 + + def test_first_no_usage_second_with_usage(self, journal_setup): + """First callback with no usage must not block second callback with usage for same run_id.""" + j, _ = journal_setup + run_id = uuid4() + j.on_llm_end(_make_llm_response("A", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + assert str(run_id) not in j._counted_llm_run_ids + # Second callback for the same run_id with actual usage must still count + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + assert j._total_tokens == 15 + assert j._lead_agent_tokens == 15 + + def test_track_token_usage_false_skips_buckets(self): + """When token tracking is disabled, caller buckets stay at 0.""" + store = MemoryRunEventStore() + j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100) + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("X", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"]) + assert j._subagent_tokens == 0 + assert j._lead_agent_tokens == 0 + + def test_default_no_tags_buckets_as_lead_agent(self, journal_setup): + """LLM calls without explicit tags default to lead_agent bucket.""" + j, _ = journal_setup + usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10} + j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None) + assert j._lead_agent_tokens == 10 + assert j._subagent_tokens == 0 + assert j._middleware_tokens == 0 + + def test_unknown_tag_buckets_as_lead_agent(self, journal_setup): + """Calls with unrecognized tags (not lead_agent/subagent:/middleware:) go to lead_agent.""" + j, _ = journal_setup + usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10} + j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["some_random_tag"]) + assert j._lead_agent_tokens == 10 + + +class TestExternalUsageRecords: + """Tests for record_external_llm_usage_records.""" + + def test_records_added_to_subagent_bucket(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-1", + "caller": "subagent:general-purpose", + "input_tokens": 100, + "output_tokens": 50, + "total_tokens": 150, + } + ] + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 150 + assert j._total_tokens == 150 + assert j._total_input_tokens == 100 + assert j._total_output_tokens == 50 + + def test_records_added_to_middleware_bucket(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-2", + "caller": "middleware:summarize", + "input_tokens": 30, + "output_tokens": 10, + "total_tokens": 40, + } + ] + j.record_external_llm_usage_records(records) + assert j._middleware_tokens == 40 + assert j._lead_agent_tokens == 0 + assert j._subagent_tokens == 0 + + def test_records_added_to_lead_agent_bucket(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-3", + "caller": "lead_agent", + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + } + ] + j.record_external_llm_usage_records(records) + assert j._lead_agent_tokens == 15 + + def test_dedup_same_source_run_id(self, journal_setup): + """Same source_run_id must not be double-counted.""" + j, _ = journal_setup + records = [ + { + "source_run_id": "dup-1", + "caller": "subagent:research", + "input_tokens": 50, + "output_tokens": 25, + "total_tokens": 75, + } + ] + j.record_external_llm_usage_records(records) + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 75 + assert j._total_tokens == 75 + + def test_total_tokens_missing_computed_from_input_output(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-4", + "caller": "subagent:bash", + "input_tokens": 200, + "output_tokens": 100, + "total_tokens": 0, + } + ] + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 300 + assert j._total_tokens == 300 + + def test_total_tokens_zero_no_count(self, journal_setup): + """Records with zero total and zero input+output must not be counted.""" + j, _ = journal_setup + records = [ + { + "source_run_id": "ext-5", + "caller": "subagent:research", + "input_tokens": 0, + "output_tokens": 0, + "total_tokens": 0, + } + ] + j.record_external_llm_usage_records(records) + assert j._total_tokens == 0 + assert j._subagent_tokens == 0 + + def test_empty_source_run_id_skipped(self, journal_setup): + j, _ = journal_setup + records = [ + { + "source_run_id": "", + "caller": "subagent:research", + "input_tokens": 50, + "output_tokens": 25, + "total_tokens": 75, + } + ] + j.record_external_llm_usage_records(records) + assert j._total_tokens == 0 + + def test_multiple_records_in_single_call(self, journal_setup): + j, _ = journal_setup + records = [ + {"source_run_id": "r1", "caller": "subagent:gp", "input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + {"source_run_id": "r2", "caller": "subagent:bash", "input_tokens": 20, "output_tokens": 10, "total_tokens": 30}, + ] + j.record_external_llm_usage_records(records) + assert j._subagent_tokens == 45 + assert j._total_tokens == 45 + + def test_external_records_coexist_with_inline_callbacks(self, journal_setup): + """External records and inline on_llm_end must not interfere.""" + j, _ = journal_setup + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.record_external_llm_usage_records([{"source_run_id": "ext-6", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}]) + assert j._lead_agent_tokens == 15 + assert j._subagent_tokens == 150 + assert j._total_tokens == 165 + + def test_track_token_usage_false_skips_external_records(self): + """When token tracking is disabled, external records must not accumulate.""" + store = MemoryRunEventStore() + j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100) + j.record_external_llm_usage_records([{"source_run_id": "ext-7", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}]) + assert j._total_tokens == 0 + assert j._subagent_tokens == 0 + + class TestChatModelStartHumanMessage: """Tests for on_chat_model_start extracting the first human message.""" diff --git a/backend/tests/test_subagent_token_collector.py b/backend/tests/test_subagent_token_collector.py new file mode 100644 index 000000000..76f003760 --- /dev/null +++ b/backend/tests/test_subagent_token_collector.py @@ -0,0 +1,161 @@ +"""Tests for SubagentTokenCollector callback handler.""" + +from unittest.mock import MagicMock +from uuid import uuid4 + +from deerflow.subagents.token_collector import SubagentTokenCollector + + +def _make_llm_response(content="Hello", usage=None): + """Create a mock LLM response with a message.""" + msg = MagicMock() + msg.content = content + msg.usage_metadata = usage + + gen = MagicMock() + gen.message = msg + + response = MagicMock() + response.generations = [[gen]] + return response + + +def _make_llm_response_from_usages(usages): + """Create a mock LLM response with one generation per usage entry.""" + generations = [] + for usage in usages: + msg = MagicMock() + msg.content = "chunk" + msg.usage_metadata = usage + + gen = MagicMock() + gen.message = msg + generations.append([gen]) + + response = MagicMock() + response.generations = generations + return response + + +class TestSubagentTokenCollector: + def test_collects_usage_from_response(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 1 + assert records[0]["caller"] == "subagent:test" + assert records[0]["input_tokens"] == 100 + assert records[0]["output_tokens"] == 50 + assert records[0]["total_tokens"] == 150 + assert "source_run_id" in records[0] + + def test_total_tokens_zero_uses_input_plus_output(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 200, "output_tokens": 100, "total_tokens": 0} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 1 + assert records[0]["total_tokens"] == 300 + + def test_total_tokens_missing_uses_input_plus_output(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 30, "output_tokens": 20} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 1 + assert records[0]["total_tokens"] == 50 + + def test_dedup_same_run_id(self): + collector = SubagentTokenCollector(caller="subagent:test") + run_id = uuid4() + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id) + collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id) + records = collector.snapshot_records() + assert len(records) == 1 + + def test_no_usage_no_record(self): + collector = SubagentTokenCollector(caller="subagent:test") + collector.on_llm_end(_make_llm_response("Hi", usage=None), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 + + def test_zero_usage_no_record(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 + + def test_skips_empty_generation_and_records_later_usage(self): + collector = SubagentTokenCollector(caller="subagent:test") + response = _make_llm_response_from_usages( + [ + None, + {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}, + ] + ) + + collector.on_llm_end(response, run_id=uuid4()) + + records = collector.snapshot_records() + assert len(records) == 1 + assert records[0]["total_tokens"] == 30 + + def test_snapshot_returns_copy(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4()) + snap1 = collector.snapshot_records() + snap2 = collector.snapshot_records() + assert snap1 == snap2 + assert snap1 is not snap2 + # Mutating snapshot does not affect internal records + snap1.append({"source_run_id": "fake"}) + assert len(collector.snapshot_records()) == 1 + + def test_multiple_calls_accumulate(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4()) + collector.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 2 + + def test_different_run_ids_accumulate_separately(self): + collector = SubagentTokenCollector(caller="subagent:test") + usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} + collector.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4()) + collector.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 2 + assert records[0]["total_tokens"] == 15 + assert records[1]["total_tokens"] == 30 + + def test_message_without_usage_metadata_skipped(self): + """A response where message has no usage_metadata attribute must be skipped.""" + collector = SubagentTokenCollector(caller="subagent:test") + + msg = MagicMock(spec=[]) # object without usage_metadata + gen = MagicMock() + gen.message = msg + response = MagicMock() + response.generations = [[gen]] + + collector.on_llm_end(response, run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 + + def test_generation_without_message_skipped(self): + """A generation without a message attribute must be skipped.""" + collector = SubagentTokenCollector(caller="subagent:test") + + gen = MagicMock(spec=[]) # object without message + response = MagicMock() + response.generations = [[gen]] + + collector.on_llm_end(response, run_id=uuid4()) + records = collector.snapshot_records() + assert len(records) == 0 diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 3be1e4b5c..0591c0e8d 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -777,22 +777,27 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): def test_cleanup_scheduled_on_cancellation(monkeypatch): - """Verify cancellation schedules deferred cleanup for the background task.""" + """Verify cancellation handler synchronously cleans up after shielded wait.""" config = _make_subagent_config() events = [] cleanup_calls = [] - scheduled_cleanup_coros = [] poll_count = 0 def get_result(_: str): nonlocal poll_count poll_count += 1 - if poll_count == 1: + # Main loop polls RUNNING twice, then shielded wait gets COMPLETED + if poll_count <= 2: return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]) return _make_result(FakeSubagentStatus.COMPLETED, result="done") - async def cancel_on_first_sleep(_: float) -> None: - raise asyncio.CancelledError + sleep_count = 0 + + async def cancel_on_second_sleep(_: float) -> None: + nonlocal sleep_count + sleep_count += 1 + if sleep_count == 2: + raise asyncio.CancelledError monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) monkeypatch.setattr( @@ -804,12 +809,7 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch): monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) - monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) - monkeypatch.setattr( - task_tool_module.asyncio, - "create_task", - lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(), - ) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_second_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -826,25 +826,48 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch): tool_call_id="tc-cancelled-cleanup", ) - assert cleanup_calls == [] - assert len(scheduled_cleanup_coros) == 1 - - asyncio.run(scheduled_cleanup_coros.pop()) - + # Cleanup happens synchronously within the cancellation handler assert cleanup_calls == ["tc-cancelled-cleanup"] def test_cancelled_cleanup_stops_after_timeout(monkeypatch): - """Verify deferred cleanup gives up after a bounded number of polls.""" + """Verify cancellation handler survives a shielded-wait timeout gracefully. + + When the subagent never reaches a terminal state, the shielded wait times + out (or is interrupted), the handler reports whatever usage it can, calls + cleanup (which is a no-op for non-terminal tasks), and re-raises. + """ config = _make_subagent_config() - config.timeout_seconds = 1 events = [] + report_calls = [] cleanup_calls = [] - scheduled_cleanup_coros = [] + scheduled_cleanups = [] + + # Always return RUNNING — subagent never finishes + monkeypatch.setattr( + task_tool_module, + "get_background_task_result", + lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), + ) async def cancel_on_first_sleep(_: float) -> None: raise asyncio.CancelledError + def fake_report_subagent_usage(runtime, result): + report_calls.append((runtime, result)) + + class DummyCleanupTask: + def __init__(self, coro): + self.coro = coro + + def add_done_callback(self, callback): + self.callback = callback + + def fake_create_task(coro): + scheduled_cleanups.append(coro) + coro.close() + return DummyCleanupTask(coro) + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) monkeypatch.setattr( task_tool_module, @@ -852,19 +875,10 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch): type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), ) monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) - - monkeypatch.setattr( - task_tool_module, - "get_background_task_result", - lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]), - ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) - monkeypatch.setattr( - task_tool_module.asyncio, - "create_task", - lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(), - ) + monkeypatch.setattr(task_tool_module.asyncio, "create_task", fake_create_task) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -881,13 +895,73 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch): tool_call_id="tc-cancelled-timeout", ) - async def bounded_sleep(_seconds: float) -> None: - return None - - monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep) - asyncio.run(scheduled_cleanup_coros.pop()) - + # Non-terminal tasks cannot be cleaned immediately; a deferred cleanup + # keeps polling after the parent cancellation path exits. assert cleanup_calls == [] + assert len(scheduled_cleanups) == 1 + # _report_subagent_usage is called (but skips because result has no records) + assert len(report_calls) == 1 + + +def test_cancellation_wait_uses_subagent_polling_budget(monkeypatch): + """Cancelled parent waits on the existing subagent polling budget, not a fixed timeout.""" + config = _make_subagent_config() + events = [] + report_calls = [] + cleanup_calls = [] + sleep_count = 0 + result_polls = 0 + terminal_result = _make_result(FakeSubagentStatus.COMPLETED, result="done") + + def get_result(_: str): + nonlocal result_polls + result_polls += 1 + if result_polls < 5: + return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]) + return terminal_result + + async def cancel_then_continue(_: float) -> None: + nonlocal sleep_count + sleep_count += 1 + if sleep_count == 1: + raise asyncio.CancelledError + + def fake_report_subagent_usage(runtime, result): + report_calls.append((runtime, result)) + + async def fail_on_fixed_timeout(awaitable, *, timeout=None): + raise AssertionError(f"cancellation wait should not use fixed timeout={timeout}") + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_then_continue) + monkeypatch.setattr(task_tool_module.asyncio, "wait_for", fail_on_fixed_timeout) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage) + monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) + monkeypatch.setattr( + task_tool_module, + "cleanup_background_task", + lambda task_id: cleanup_calls.append(task_id), + ) + + with pytest.raises(asyncio.CancelledError): + _run_task_tool( + runtime=_make_runtime(), + description="执行任务", + prompt="cancel task", + subagent_type="general-purpose", + tool_call_id="tc-cancel-budget", + ) + + assert report_calls == [(_make_runtime(), terminal_result)] + assert cleanup_calls == ["tc-cancel-budget"] def test_cancellation_calls_request_cancel(monkeypatch): @@ -895,7 +969,6 @@ def test_cancellation_calls_request_cancel(monkeypatch): config = _make_subagent_config() events = [] cancel_requests = [] - scheduled_cleanup_coros = [] async def cancel_on_first_sleep(_: float) -> None: raise asyncio.CancelledError @@ -915,11 +988,6 @@ def test_cancellation_calls_request_cancel(monkeypatch): ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep) - monkeypatch.setattr( - task_tool_module.asyncio, - "create_task", - lambda coro: (coro.close(), scheduled_cleanup_coros.append(None))[-1] or _DummyScheduledTask(), - ) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, @@ -987,3 +1055,80 @@ def test_task_tool_returns_cancelled_message(monkeypatch): assert output == "Task cancelled by user." assert any(e.get("type") == "task_cancelled" for e in events) assert cleanup_calls == ["tc-poll-cancelled"] + + +def test_cancellation_reports_subagent_usage(monkeypatch): + """Verify cancellation handler waits (shielded) for subagent terminal state, + then reports the final token usage before re-raising CancelledError. + + The report must happen synchronously within the cancellation handler so + the parent worker's finally block sees the updated journal totals. + """ + config = _make_subagent_config() + events = [] + report_calls = [] + cleanup_calls = [] + + # Terminal result with token usage collected after cancellation processing + cancel_result = _make_result(FakeSubagentStatus.CANCELLED, error="Cancelled by user") + cancel_result.token_usage_records = [{"source_run_id": "sub-run-1", "caller": "subagent:gp", "input_tokens": 50, "output_tokens": 25, "total_tokens": 75}] + cancel_result.usage_reported = False + + poll_count = 0 + + def get_result(_: str): + nonlocal poll_count + poll_count += 1 + # Main loop polls 3 times (RUNNING each time to keep looping) + if poll_count <= 3: + running = _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]) + running.token_usage_records = [] + running.usage_reported = False + return running + # Shielded wait poll gets the terminal result + return cancel_result + + sleep_count = 0 + + async def cancel_on_third_sleep(_: float) -> None: + nonlocal sleep_count + sleep_count += 1 + if sleep_count == 3: + raise asyncio.CancelledError + + def fake_report_subagent_usage(runtime, result): + report_calls.append((runtime, result)) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_third_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage) + monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) + monkeypatch.setattr(task_tool_module, "request_cancel_background_task", lambda _: None) + monkeypatch.setattr( + task_tool_module, + "cleanup_background_task", + lambda task_id: cleanup_calls.append(task_id), + ) + + with pytest.raises(asyncio.CancelledError): + _run_task_tool( + runtime=_make_runtime(), + description="执行任务", + prompt="cancel me", + subagent_type="general-purpose", + tool_call_id="tc-cancel-report", + ) + + # _report_subagent_usage is called synchronously within the cancellation + # handler (after the shielded wait), before CancelledError is re-raised. + assert len(report_calls) == 1 + assert report_calls[0][1] is cancel_result + assert cleanup_calls == ["tc-cancel-report"] diff --git a/frontend/tests/unit/core/threads/api.test.ts b/frontend/tests/unit/core/threads/api.test.ts index d91a2bcdf..4d1268694 100644 --- a/frontend/tests/unit/core/threads/api.test.ts +++ b/frontend/tests/unit/core/threads/api.test.ts @@ -20,7 +20,11 @@ test("fetchThreadTokenUsage uses shared auth fetch without JSON GET headers", as total_tokens: 7, total_runs: 1, by_model: { unknown: { tokens: 7, runs: 1 } }, - by_caller: {}, + by_caller: { + lead_agent: 0, + subagent: 0, + middleware: 0, + }, }), }); From 30a58462192386cb57924d670c23412855776c96 Mon Sep 17 00:00:00 2001 From: Maz Benoscar Date: Sun, 10 May 2026 11:09:03 -0400 Subject: [PATCH 02/86] fix(tools): make write_file append discoverable in model-facing schema (#2843) * fix: make tool argument behavior discoverable The write_file tool already supported append=false by default with append=true for end-of-file writes, but the parsed docstring did not describe append in the model-facing schema. This records the overwrite default and append path in the tool description, adds resilient schema regression coverage, and keeps backend sandbox docs aligned. The regression now also checks that every public parameter in the existing tool schema test matrix has a description. Enabling docstring parsing on setup_agent and update_agent fills the two existing gaps with their existing Args docs instead of duplicating descriptions elsewhere. Constraint: Issue #2831 asks for a small docstring/schema discoverability fix without changing runtime file-writing behavior Rejected: Changing write_file defaults | would alter existing overwrite semantics and broaden the fix beyond schema discoverability Rejected: Exact phrase assertions | too brittle for future docstring rewording while testing the same behavior Confidence: high Scope-risk: narrow Directive: Keep model-facing tool parameters documented through parsed docstrings or equivalent schema descriptions Tested: cd backend && uv run pytest tests/test_setup_agent_tool.py tests/test_update_agent_tool.py tests/test_tool_args_schema_no_pydantic_warning.py tests/test_sandbox_tools_security.py::test_str_replace_and_append_on_same_path_should_preserve_both_updates -q Tested: cd backend && uv run ruff check packages/harness/deerflow/sandbox/tools.py packages/harness/deerflow/tools/builtins/setup_agent_tool.py packages/harness/deerflow/tools/builtins/update_agent_tool.py tests/test_tool_args_schema_no_pydantic_warning.py Not-tested: Full backend test suite Co-authored-by: OmX * Fix the lint error --------- Co-authored-by: OmX Co-authored-by: Willem Jiang --- backend/CLAUDE.md | 2 +- backend/README.md | 2 +- .../packages/harness/deerflow/sandbox/tools.py | 3 ++- .../deerflow/tools/builtins/setup_agent_tool.py | 2 +- .../tools/builtins/update_agent_tool.py | 2 +- ...test_tool_args_schema_no_pydantic_warning.py | 17 +++++++++++++++++ 6 files changed, 23 insertions(+), 5 deletions(-) diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index d03aeefd8..99922a61e 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -243,7 +243,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → - `bash` - Execute commands with path translation and error handling - `ls` - Directory listing (tree format, max 2 levels) - `read_file` - Read file contents with optional line range -- `write_file` - Write/append to files, creates directories +- `write_file` - Write/append to files, creates directories; overwrites by default and exposes the `append` argument in the model-facing schema for end-of-file writes - `str_replace` - Substring replacement (single or all occurrences); same-path serialization is scoped to `(sandbox.id, path)` so isolated sandboxes do not contend on identical virtual paths inside one process ### Subagent System (`packages/harness/deerflow/subagents/`) diff --git a/backend/README.md b/backend/README.md index 0e2d966ee..6295eba22 100644 --- a/backend/README.md +++ b/backend/README.md @@ -79,7 +79,7 @@ Per-thread isolated execution with virtual path translation: - **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory - **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths - **File-write safety**: `str_replace` serializes read-modify-write per `(sandbox.id, path)` so isolated sandboxes keep concurrency even when virtual paths match -- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access) +- **Tools**: `bash`, `ls`, `read_file`, `write_file`, `str_replace` (`write_file` overwrites by default and exposes `append` for end-of-file writes; `bash` is disabled by default when using `LocalSandboxProvider`; use `AioSandboxProvider` for isolated shell access) ### Subagent System diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index a20004a8a..7c746b1aa 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -1499,12 +1499,13 @@ def write_file_tool( content: str, append: bool = False, ) -> str: - """Write text content to a file. + """Write text content to a file. By default this overwrites the target file; set append to true to add content to the end without replacing existing content. Args: description: Explain why you are writing to this file in short words. ALWAYS PROVIDE THIS PARAMETER FIRST. path: The **absolute** path to the file to write to. ALWAYS PROVIDE THIS PARAMETER SECOND. content: The content to write to the file. ALWAYS PROVIDE THIS PARAMETER THIRD. + append: Whether to append content to the end of the file instead of overwriting it. Defaults to false. """ try: sandbox = ensure_sandbox_initialized(runtime) diff --git a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py index 97929ad56..2f796b005 100644 --- a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py @@ -20,7 +20,7 @@ def _get_runtime_user_id(runtime: Runtime) -> str: return get_effective_user_id() -@tool +@tool(parse_docstring=True) def setup_agent( soul: str, description: str, diff --git a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py index 90d951859..b2dc8ca72 100644 --- a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py @@ -67,7 +67,7 @@ def _cleanup_temps(temps: list[Path]) -> None: logger.debug("Failed to clean up temp file %s", tmp, exc_info=True) -@tool +@tool(parse_docstring=True) def update_agent( runtime: Runtime, soul: str | None = None, diff --git a/backend/tests/test_tool_args_schema_no_pydantic_warning.py b/backend/tests/test_tool_args_schema_no_pydantic_warning.py index 037771b3e..6da56347f 100644 --- a/backend/tests/test_tool_args_schema_no_pydantic_warning.py +++ b/backend/tests/test_tool_args_schema_no_pydantic_warning.py @@ -89,3 +89,20 @@ def test_tool_args_schema_does_not_emit_pydantic_context_warning(tool_obj, extra pydantic_warnings = [w for w in caught if "PydanticSerializationUnexpectedValue" in str(w.message)] assert not pydantic_warnings, f"{tool_obj.name} args_schema.model_dump() emitted Pydantic context serialization warnings: {[str(w.message) for w in pydantic_warnings]}" + + +def test_write_file_append_is_discoverable_in_tool_schema() -> None: + """``append`` must be visible and described in the model-facing tool schema.""" + assert "append" in write_file_tool.description + + append_field = write_file_tool.tool_call_schema.model_fields["append"] + assert append_field.default is False + assert append_field.description + assert "append" in append_field.description + + +@pytest.mark.parametrize("tool_obj", [case[0] for case in _TOOL_CASES], ids=[case[0].name for case in _TOOL_CASES]) +def test_model_facing_tool_parameters_have_descriptions(tool_obj) -> None: + """Every model-facing tool parameter should explain when and how to use it.""" + missing_descriptions = [field_name for field_name, field in tool_obj.tool_call_schema.model_fields.items() if not field.description] + assert missing_descriptions == [], f"{tool_obj.name} has model-facing parameters without descriptions: {missing_descriptions}. Add an Args: section to the tool's docstring and ensure @tool(parse_docstring=True) is set." From e82b2fb4d0a8feb333c4b533292e5ea16a136f08 Mon Sep 17 00:00:00 2001 From: YuJitang Date: Mon, 11 May 2026 07:17:49 +0800 Subject: [PATCH 03/86] docs: clarify token usage accounting semantics (#2845) --- .../en/application/workspace-usage.mdx | 20 +++++++++++++++++++ .../zh/application/workspace-usage.mdx | 11 ++++++++++ frontend/src/core/i18n/locales/en-US.ts | 2 +- frontend/src/core/i18n/locales/zh-CN.ts | 2 +- 4 files changed, 33 insertions(+), 2 deletions(-) diff --git a/frontend/src/content/en/application/workspace-usage.mdx b/frontend/src/content/en/application/workspace-usage.mdx index 686614aa7..253519af8 100644 --- a/frontend/src/content/en/application/workspace-usage.mdx +++ b/frontend/src/content/en/application/workspace-usage.mdx @@ -67,6 +67,26 @@ Each agent response in the conversation may contain: Tool calls and thinking steps are collapsed by default. Click to expand them. +## Understanding token usage + +If token usage display is enabled, DeerFlow shows one conversation-level total in +the header and optional per-turn or debug summaries in the message list. + +- **Header total**: the persisted thread-level total from the backend. While the + current run is still streaming, the header may also include the visible + in-flight usage for that unfinished response. +- **Per-turn / debug usage**: usage derived from the assistant messages that are + currently visible in the conversation view. + +This means the header total and the visible per-turn totals do **not** need to +add up exactly. The header is a thread ledger; the per-turn view is a rendering +of the messages you can currently see. + +These totals may also differ from your provider's billing page. Common reasons +include retries, failed requests, cached input tokens, reasoning tokens, +provider-specific billing rules, and internal calls that do not appear as normal +chat messages. + ## Switching agents If you have created custom agents, use the **Agent** selector in the input bar to switch to a different agent. The selected agent persists for the duration of the thread. diff --git a/frontend/src/content/zh/application/workspace-usage.mdx b/frontend/src/content/zh/application/workspace-usage.mdx index e4e3fb541..35cafcc84 100644 --- a/frontend/src/content/zh/application/workspace-usage.mdx +++ b/frontend/src/content/zh/application/workspace-usage.mdx @@ -70,6 +70,17 @@ DeerFlow 工作区是一个基于浏览器的对话界面,你可以在其中 点击消息旁边的展开箭头查看完整的推理链。 +## 理解 Token 用量 + +如果启用了 Token 用量显示,DeerFlow 会在顶部显示一个对话级总量,并在消息列表中按配置显示每轮或调试级别的用量摘要。 + +- **顶部总量**:后端持久化的线程级总账。当当前回复仍在流式返回时,顶部还可能临时叠加这条未完成回复的可见进行中用量。 +- **每轮 / 调试用量**:根据当前界面里可见的 assistant 消息计算出来的用量。 + +因此,顶部总量和当前可见的每轮总和**不要求完全相等**。顶部展示的是整个线程的总账;每轮展示的是你当前能看到的消息视图。 + +这些数字也可能与模型供应商的账单页不同。常见原因包括重试请求、失败请求、缓存输入 token、推理 token、供应商自己的计费口径,以及不会以普通聊天消息形式显示的内部调用。 + ## 查看产出物 当 Agent 生成文件(报告、图表、代码文件、演示文稿)时,它们会以**产出物**的形式出现在对话中。 diff --git a/frontend/src/core/i18n/locales/en-US.ts b/frontend/src/core/i18n/locales/en-US.ts index 1daaa21b0..b6ce0c76a 100644 --- a/frontend/src/core/i18n/locales/en-US.ts +++ b/frontend/src/core/i18n/locales/en-US.ts @@ -310,7 +310,7 @@ export const enUS: Translations = { unavailable: "No token usage yet. Usage appears only after a successful model response when the provider returns usage_metadata.", unavailableShort: "No usage returned", - note: "Header totals use persisted thread usage when available. Per-turn and debug usage come from visible messages. Totals may differ from provider billing pages.", + note: "Header totals use persisted thread usage, plus visible in-flight usage while a run is still streaming. Per-turn and debug usage come from currently visible messages only. Totals may differ from provider billing pages.", presets: { off: "Off", summary: "Summary", diff --git a/frontend/src/core/i18n/locales/zh-CN.ts b/frontend/src/core/i18n/locales/zh-CN.ts index aadedad65..105aca551 100644 --- a/frontend/src/core/i18n/locales/zh-CN.ts +++ b/frontend/src/core/i18n/locales/zh-CN.ts @@ -296,7 +296,7 @@ export const zhCN: Translations = { unavailable: "暂无 Token 用量。只有模型成功返回且供应商提供 usage_metadata 时才会显示。", unavailableShort: "未返回用量", - note: "顶部总量优先使用后端持久化的线程用量。每轮和调试用量来自当前可见消息,可能与平台账单页不完全一致。", + note: "顶部总量优先使用后端持久化的线程用量;当当前回复仍在流式返回时,还会叠加可见的进行中用量。每轮和调试用量只来自当前可见消息,可能与平台账单页不完全一致。", presets: { off: "关闭", summary: "总览", From 2b5bece7441728529e58585ccb2dddf7b1716732 Mon Sep 17 00:00:00 2001 From: KiteEater <145987840+Kiteeater@users.noreply.github.com> Date: Mon, 11 May 2026 07:42:15 +0800 Subject: [PATCH 04/86] fix(harness): reset local sandbox singleton with provider lifecycle (#2834) * Fix local sandbox singleton reset on provider lifecycle * Fix local sandbox singleton reset on provider reset --------- Co-authored-by: Willem Jiang --- .../sandbox/local/local_sandbox_provider.py | 10 ++ .../deerflow/sandbox/sandbox_provider.py | 13 +- .../test_local_sandbox_provider_mounts.py | 145 ++++++++++++++++++ 3 files changed, 167 insertions(+), 1 deletion(-) diff --git a/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py b/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py index 651db11ec..0510a2473 100644 --- a/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py +++ b/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py @@ -119,3 +119,13 @@ class LocalSandboxProvider(SandboxProvider): # For Docker-based providers (e.g., AioSandboxProvider), cleanup # happens at application shutdown via the shutdown() method. pass + + def reset(self) -> None: + # reset_sandbox_provider() must also clear the module singleton. + global _singleton + _singleton = None + + def shutdown(self) -> None: + # LocalSandboxProvider has no extra resources beyond the shared + # singleton, so shutdown uses the same cleanup path as reset. + self.reset() diff --git a/backend/packages/harness/deerflow/sandbox/sandbox_provider.py b/backend/packages/harness/deerflow/sandbox/sandbox_provider.py index ecb1f7a67..0aa4d619a 100644 --- a/backend/packages/harness/deerflow/sandbox/sandbox_provider.py +++ b/backend/packages/harness/deerflow/sandbox/sandbox_provider.py @@ -37,6 +37,10 @@ class SandboxProvider(ABC): """ pass + def reset(self) -> None: + """Clear cached state that survives provider instance replacement.""" + pass + _default_sandbox_provider: SandboxProvider | None = None @@ -65,11 +69,18 @@ def reset_sandbox_provider() -> None: The next call to `get_sandbox_provider()` will create a new instance. Useful for testing or when switching configurations. + Providers can override `reset()` to clear any module-level state they keep + alive across instances (for example, `LocalSandboxProvider`'s cached + `LocalSandbox` singleton). Without it, config/mount changes would not take + effect on the next acquire(). + Note: If the provider has active sandboxes, they will be orphaned. Use `shutdown_sandbox_provider()` for proper cleanup. """ global _default_sandbox_provider - _default_sandbox_provider = None + if _default_sandbox_provider is not None: + _default_sandbox_provider.reset() + _default_sandbox_provider = None def shutdown_sandbox_provider() -> None: diff --git a/backend/tests/test_local_sandbox_provider_mounts.py b/backend/tests/test_local_sandbox_provider_mounts.py index 5c50a1aa0..5e7a06b6d 100644 --- a/backend/tests/test_local_sandbox_provider_mounts.py +++ b/backend/tests/test_local_sandbox_provider_mounts.py @@ -639,3 +639,148 @@ class TestLocalSandboxProviderMounts: provider = LocalSandboxProvider() assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"] + + +class TestLocalSandboxProviderResetClearsSingleton: + """Regression coverage for issue #2815. + + The module-level LocalSandbox singleton must be cleared whenever the + provider is reset or shut down — otherwise stale path mappings and + mount policy survive config reloads and test teardown. + """ + + def _build_config(self, skills_dir, mounts): + from deerflow.config.sandbox_config import SandboxConfig + + sandbox_config = SandboxConfig( + use="deerflow.sandbox.local:LocalSandboxProvider", + mounts=mounts, + ) + return SimpleNamespace( + skills=SimpleNamespace( + container_path="/mnt/skills", + get_skills_path=lambda: skills_dir, + use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage", + ), + sandbox=sandbox_config, + ) + + def test_reset_sandbox_provider_clears_local_singleton(self, tmp_path): + from deerflow.config.sandbox_config import VolumeMountConfig + from deerflow.sandbox import local as local_module + from deerflow.sandbox.local import local_sandbox_provider as lsp_module + from deerflow.sandbox.sandbox_provider import ( + get_sandbox_provider, + reset_sandbox_provider, + ) + + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + first_dir = tmp_path / "first" + first_dir.mkdir() + second_dir = tmp_path / "second" + second_dir.mkdir() + + first_cfg = self._build_config( + skills_dir, + [VolumeMountConfig(host_path=str(first_dir), container_path="/mnt/first", read_only=False)], + ) + second_cfg = self._build_config( + skills_dir, + [VolumeMountConfig(host_path=str(second_dir), container_path="/mnt/second", read_only=False)], + ) + + # Make sure no leftover singleton from a prior test interferes. + lsp_module._singleton = None + reset_sandbox_provider() + + try: + with patch("deerflow.sandbox.sandbox_provider.get_app_config", return_value=first_cfg), patch("deerflow.config.get_app_config", return_value=first_cfg): + provider = get_sandbox_provider() + provider.acquire() + + assert lsp_module._singleton is not None + first_container_paths = {m.container_path for m in lsp_module._singleton.path_mappings} + assert "/mnt/first" in first_container_paths + + reset_sandbox_provider() + + # The whole point of the regression: reset must drop the cached LocalSandbox. + assert lsp_module._singleton is None + + with patch("deerflow.sandbox.sandbox_provider.get_app_config", return_value=second_cfg), patch("deerflow.config.get_app_config", return_value=second_cfg): + provider2 = get_sandbox_provider() + provider2.acquire() + + assert provider2 is not provider + second_container_paths = {m.container_path for m in lsp_module._singleton.path_mappings} + assert "/mnt/second" in second_container_paths + assert "/mnt/first" not in second_container_paths + finally: + lsp_module._singleton = None + reset_sandbox_provider() + + # Sanity: the local sandbox module still exposes the singleton symbol + # at the same module path (guards against accidental rename). + assert hasattr(local_module.local_sandbox_provider, "_singleton") + + def test_shutdown_sandbox_provider_clears_local_singleton(self, tmp_path): + from deerflow.config.sandbox_config import VolumeMountConfig + from deerflow.sandbox.local import local_sandbox_provider as lsp_module + from deerflow.sandbox.sandbox_provider import ( + get_sandbox_provider, + reset_sandbox_provider, + shutdown_sandbox_provider, + ) + + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + mount_dir = tmp_path / "mount" + mount_dir.mkdir() + + cfg = self._build_config( + skills_dir, + [VolumeMountConfig(host_path=str(mount_dir), container_path="/mnt/data", read_only=False)], + ) + + lsp_module._singleton = None + reset_sandbox_provider() + + try: + with patch("deerflow.sandbox.sandbox_provider.get_app_config", return_value=cfg), patch("deerflow.config.get_app_config", return_value=cfg): + provider = get_sandbox_provider() + provider.acquire() + + assert lsp_module._singleton is not None + + shutdown_sandbox_provider() + + assert lsp_module._singleton is None + finally: + lsp_module._singleton = None + reset_sandbox_provider() + + def test_provider_reset_method_is_idempotent(self, tmp_path): + from deerflow.sandbox.local import local_sandbox_provider as lsp_module + from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider + + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + cfg = self._build_config(skills_dir, []) + + lsp_module._singleton = None + + try: + with patch("deerflow.config.get_app_config", return_value=cfg): + provider = LocalSandboxProvider() + provider.acquire() + assert lsp_module._singleton is not None + + provider.reset() + assert lsp_module._singleton is None + + # Calling reset again on an already-cleared singleton is safe. + provider.reset() + assert lsp_module._singleton is None + finally: + lsp_module._singleton = None From 813d3c94efa7fdea6aafcb4f459304db91fcaed0 Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Mon, 11 May 2026 09:59:06 +0800 Subject: [PATCH 05/86] fix(subagents): consolidate system_prompt and skills into single SystemMessage (#2701) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(subagents): consolidate system_prompt and skills into single SystemMessage Some LLM APIs (vLLM, Xinference, Chinese LLM providers) reject multiple system messages with \”System message must be at the beginning.\” The subagent executor was sending separate SystemMessages for the configured system_prompt and each loaded skill, which caused failures when calling task tool with sub-agents. Merge system_prompt and all skill content into one SystemMessage in the initial state, and pass system_prompt=None to create_agent() so the factory doesn't prepend a second one. Fixes #2693 * fix(subagents): update SubagentConfig.system_prompt to str | None and add astream regression test Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/2ee03a26-e19b-4106-abc5-c76a2906383b Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * fixed the lint error * fix the lint error in the backend * fix the unit test error of test_subagent_executor --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- .../harness/deerflow/subagents/config.py | 2 +- .../harness/deerflow/subagents/executor.py | 19 +- backend/tests/test_subagent_executor.py | 184 +++++++++++++++++- 3 files changed, 200 insertions(+), 5 deletions(-) diff --git a/backend/packages/harness/deerflow/subagents/config.py b/backend/packages/harness/deerflow/subagents/config.py index b0b094e28..9081e2df9 100644 --- a/backend/packages/harness/deerflow/subagents/config.py +++ b/backend/packages/harness/deerflow/subagents/config.py @@ -26,7 +26,7 @@ class SubagentConfig: name: str description: str - system_prompt: str + system_prompt: str | None = None tools: list[str] | None = None disallowed_tools: list[str] | None = field(default_factory=lambda: ["task"]) skills: list[str] | None = None diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index a2fec6432..d6d2e4fc5 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -286,11 +286,13 @@ class SubagentExecutor: # Reuse shared middleware composition with lead agent. middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True) + # system_prompt is included in initial state messages (see _build_initial_state) + # to avoid multiple SystemMessages which some LLM APIs don't support. return create_agent( model=model, tools=tools if tools is not None else self.tools, middleware=middlewares, - system_prompt=self.config.system_prompt, + system_prompt=None, state_schema=ThreadState, ) @@ -365,14 +367,25 @@ class SubagentExecutor: Returns: Initial state dictionary and tools filtered by loaded skill metadata. """ + # Load skills as conversation items (Codex pattern) skills = await self._load_skills() filtered_tools = self._apply_skill_allowed_tools(skills) skill_messages = await self._load_skill_messages(skills) + # Combine system_prompt and skills into a single SystemMessage. + # Some LLM APIs reject multiple SystemMessages with + # "System message must be at the beginning." + system_parts: list[str] = [] + if self.config.system_prompt: + system_parts.append(self.config.system_prompt) + for skill_msg in skill_messages: + system_parts.append(skill_msg.content) + messages: list[Any] = [] - # Skill content injected as developer/system messages before the task - messages.extend(skill_messages) + if system_parts: + messages.append(SystemMessage(content="\n\n".join(system_parts))) + # Then the actual task messages.append(HumanMessage(content=task)) diff --git a/backend/tests/test_subagent_executor.py b/backend/tests/test_subagent_executor.py index b8da323f4..87c82ff96 100644 --- a/backend/tests/test_subagent_executor.py +++ b/backend/tests/test_subagent_executor.py @@ -291,7 +291,7 @@ class TestAgentConstruction: assert captured["agent"]["model"] is model assert captured["agent"]["middleware"] is middlewares assert captured["agent"]["tools"] == [] - assert captured["agent"]["system_prompt"] == base_config.system_prompt + assert captured["agent"]["system_prompt"] is None # system_prompt is merged into initial state messages @pytest.mark.anyio async def test_load_skill_messages_uses_explicit_app_config_for_skill_storage( @@ -331,6 +331,124 @@ class TestAgentConstruction: assert len(messages) == 1 assert "Use demo skill" in messages[0].content + @pytest.mark.anyio + async def test_build_initial_state_consolidates_system_prompt_and_skills( + self, + classes, + base_config, + monkeypatch: pytest.MonkeyPatch, + tmp_path, + ): + """_build_initial_state merges system_prompt and skills into one SystemMessage.""" + SubagentExecutor = classes["SubagentExecutor"] + + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + skill_file = skill_dir / "SKILL.md" + skill_file.write_text("Skill instructions here", encoding="utf-8") + + monkeypatch.setattr( + sys.modules["deerflow.skills.storage"], + "get_or_new_skill_storage", + lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="my-skill", skill_file=skill_file, allowed_tools=None)]), + ) + + executor = SubagentExecutor( + config=base_config, + tools=[], + thread_id="test-thread", + ) + + state, _filtered_tools = await executor._build_initial_state("Do the task") + + messages = state["messages"] + # Should have exactly 2 messages: one combined SystemMessage + one HumanMessage + assert len(messages) == 2 + + from langchain_core.messages import HumanMessage, SystemMessage + + assert isinstance(messages[0], SystemMessage) + assert isinstance(messages[1], HumanMessage) + # SystemMessage should contain both the system_prompt and skill content + assert base_config.system_prompt in messages[0].content + assert "Skill instructions here" in messages[0].content + # HumanMessage should be the task + assert messages[1].content == "Do the task" + + @pytest.mark.anyio + async def test_build_initial_state_no_skills_only_system_prompt( + self, + classes, + base_config, + monkeypatch: pytest.MonkeyPatch, + ): + """_build_initial_state works when there are no skills.""" + SubagentExecutor = classes["SubagentExecutor"] + + monkeypatch.setattr( + sys.modules["deerflow.skills.storage"], + "get_or_new_skill_storage", + lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: []), + ) + + executor = SubagentExecutor( + config=base_config, + tools=[], + thread_id="test-thread", + ) + + state, _filtered_tools = await executor._build_initial_state("Do the task") + + messages = state["messages"] + from langchain_core.messages import HumanMessage, SystemMessage + + assert len(messages) == 2 + assert isinstance(messages[0], SystemMessage) + assert base_config.system_prompt in messages[0].content + assert isinstance(messages[1], HumanMessage) + + @pytest.mark.anyio + async def test_build_initial_state_no_system_prompt_with_skills( + self, + classes, + monkeypatch: pytest.MonkeyPatch, + tmp_path, + ): + """_build_initial_state works when there is no system_prompt but there are skills.""" + SubagentConfig = classes["SubagentConfig"] + + config = SubagentConfig( + name="test-agent", + description="Test agent", + system_prompt=None, + max_turns=10, + timeout_seconds=60, + ) + + skill_dir = tmp_path / "my-skill" + skill_dir.mkdir() + skill_file = skill_dir / "SKILL.md" + skill_file.write_text("Skill content", encoding="utf-8") + + monkeypatch.setattr( + sys.modules["deerflow.skills.storage"], + "get_or_new_skill_storage", + lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="my-skill", skill_file=skill_file, allowed_tools=None)]), + ) + + SubagentExecutor = classes["SubagentExecutor"] + executor = SubagentExecutor(config=config, tools=[], thread_id="test-thread") + + state, _filtered_tools = await executor._build_initial_state("Do the task") + + messages = state["messages"] + from langchain_core.messages import HumanMessage, SystemMessage + + assert len(messages) == 2 + assert isinstance(messages[0], SystemMessage) + assert "Skill content" in messages[0].content + assert isinstance(messages[1], HumanMessage) + # ----------------------------------------------------------------------------- # Async Execution Path Tests @@ -514,6 +632,70 @@ class TestAsyncExecutionPath: assert result.status == SubagentStatus.COMPLETED assert "Task" in result.result + @pytest.mark.anyio + async def test_aexecute_passes_at_most_one_system_message_to_agent( + self, + classes, + base_config, + monkeypatch: pytest.MonkeyPatch, + tmp_path, + ): + """Regression: messages sent to agent.astream must contain at most one + SystemMessage and it must be the first message. + + This catches any regression where system_prompt would be re-injected + via create_agent() (e.g. system_prompt not passed as None) and appear + as a second SystemMessage, which providers like vLLM and Xinference + reject with "System message must be at the beginning." + """ + from langchain_core.messages import AIMessage, SystemMessage + + SubagentExecutor = classes["SubagentExecutor"] + SubagentStatus = classes["SubagentStatus"] + + # Set up a skill so both system_prompt AND skill content are present, + # maximising the chance of catching a double-SystemMessage regression. + skill_dir = tmp_path / "regression-skill" + skill_dir.mkdir() + (skill_dir / "SKILL.md").write_text("Skill instruction text", encoding="utf-8") + + monkeypatch.setattr( + sys.modules["deerflow.skills.storage"], + "get_or_new_skill_storage", + lambda *, app_config=None: SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="regression-skill", skill_file=skill_dir / "SKILL.md", allowed_tools=None)]), + ) + + captured_states: list[dict] = [] + + async def capturing_astream(state, **kwargs): + captured_states.append(state) + yield {"messages": [AIMessage(content="Done", id="msg-1")]} + + mock_agent = MagicMock() + mock_agent.astream = capturing_astream + + executor = SubagentExecutor( + config=base_config, + tools=[], + thread_id="test-thread", + ) + + with patch.object(executor, "_create_agent", return_value=mock_agent): + result = await executor._aexecute("Do something") + + assert result.status == SubagentStatus.COMPLETED + assert len(captured_states) == 1, "astream should be called exactly once" + initial_messages = captured_states[0]["messages"] + + system_messages = [m for m in initial_messages if isinstance(m, SystemMessage)] + assert len(system_messages) <= 1, f"Expected at most 1 SystemMessage but got {len(system_messages)}: {system_messages}" + if system_messages: + assert initial_messages[0] is system_messages[0], "SystemMessage must be the first message in the conversation" + # The consolidated SystemMessage must carry both the system_prompt + # and all skill content — nothing should be split across two messages. + assert base_config.system_prompt in system_messages[0].content + assert "Skill instruction text" in system_messages[0].content + class TestSkillAllowedTools: @pytest.mark.anyio From c3bc6c7cd5f0208464301f2f4c772d956920b2dd Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Mon, 11 May 2026 17:38:37 +0800 Subject: [PATCH 06/86] fix(nginx): defer CORS to gateway allowlist (#2861) * fix(nginx): defer cors to gateway allowlist Remove proxy-level wildcard CORS handling so browser origins are controlled by the Gateway allowlist and stay aligned with CSRF origin checks. * docs: document gateway cors allowlist Clarify that same-origin nginx access needs no CORS headers while split-origin or port-forwarded browser clients must opt in with GATEWAY_CORS_ORIGINS. * docs(gateway): record cors source of truth Document that Gateway CORSMiddleware and CSRFMiddleware share GATEWAY_CORS_ORIGINS as the split-origin source of truth. * fix(gateway): align cors origin normalization * docs: clarify gateway langgraph routing * docs(gateway): update runtime routing note --- .env.example | 5 +- CONTRIBUTING.md | 32 +++++-------- README.md | 2 + backend/CLAUDE.md | 4 +- backend/README.md | 41 ++++++++-------- backend/app/gateway/app.py | 48 +++++++++---------- backend/app/gateway/config.py | 3 -- backend/app/gateway/csrf_middleware.py | 9 +++- backend/docs/API.md | 30 +++++------- backend/docs/ARCHITECTURE.md | 20 ++++---- backend/tests/test_gateway_docs_toggle.py | 42 ++++++++++++++++ backend/tests/test_gateway_runtime_cleanup.py | 23 +++++++++ docker/nginx/nginx.conf | 20 ++------ docker/nginx/nginx.local.conf | 20 ++------ 14 files changed, 169 insertions(+), 130 deletions(-) diff --git a/.env.example b/.env.example index a859ec2a5..43290954b 100644 --- a/.env.example +++ b/.env.example @@ -9,8 +9,9 @@ JINA_API_KEY=your-jina-api-key # InfoQuest API Key INFOQUEST_API_KEY=your-infoquest-api-key -# CORS Origins (comma-separated) - e.g., http://localhost:3000,http://localhost:3001 -# CORS_ORIGINS=http://localhost:3000 +# Browser CORS allowlist for split-origin or port-forwarded deployments (comma-separated exact origins). +# Leave unset when using the unified nginx endpoint, e.g. http://localhost:2026. +# GATEWAY_CORS_ORIGINS=http://localhost:3000,http://127.0.0.1:3000 # Optional: # FIRECRAWL_API_KEY=your-firecrawl-api-key diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b7cb2840b..51b834b4f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -46,12 +46,12 @@ Docker provides a consistent, isolated environment with all dependencies pre-con All services will start with hot-reload enabled: - Frontend changes are automatically reloaded - Backend changes trigger automatic restart - - LangGraph server supports hot-reload + - Gateway-hosted LangGraph-compatible runtime supports hot-reload 4. **Access the application**: - Web Interface: http://localhost:2026 - API Gateway: http://localhost:2026/api/* - - LangGraph: http://localhost:2026/api/langgraph/* + - LangGraph-compatible API: http://localhost:2026/api/langgraph/* #### Docker Commands @@ -94,7 +94,7 @@ Use these as practical starting points for development and review environments: If `make docker-init`, `make docker-start`, or `make docker-stop` fails on Linux with an error like below, your current user likely does not have permission to access the Docker daemon socket: ```text -unable to get image 'deer-flow-dev-langgraph': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock +unable to get image 'deer-flow-gateway': permission denied while trying to connect to the Docker daemon socket at unix:///var/run/docker.sock ``` Recommended fix: add your current user to the `docker` group so Docker commands work without `sudo`. @@ -131,9 +131,8 @@ Host Machine Docker Compose (deer-flow-dev) ├→ nginx (port 2026) ← Reverse proxy ├→ web (port 3000) ← Frontend with hot-reload - ├→ api (port 8001) ← Gateway API with hot-reload - ├→ langgraph (port 2024) ← LangGraph server with hot-reload - └→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode + ├→ gateway (port 8001) ← Gateway API + LangGraph-compatible runtime with hot-reload + └→ provisioner (optional, port 8002) ← Started only in provisioner/K8s sandbox mode ``` **Benefits of Docker Development**: @@ -184,17 +183,13 @@ Required tools: If you need to start services individually: -1. **Start backend services**: +1. **Start backend service**: ```bash - # Terminal 1: Start LangGraph Server (port 2024) - cd backend - make dev - - # Terminal 2: Start Gateway API (port 8001) + # Terminal 1: Start Gateway API and embedded LangGraph-compatible runtime (port 8001) cd backend make gateway - # Terminal 3: Start Frontend (port 3000) + # Terminal 2: Start Frontend (port 3000) cd frontend pnpm dev ``` @@ -212,10 +207,10 @@ If you need to start services individually: The nginx configuration provides: - Unified entry point on port 2026 -- Routes `/api/langgraph/*` to LangGraph Server (2024) +- Gateway owns `/api/langgraph/*` and translates those public LangGraph-compatible paths to its native `/api/*` routers behind nginx - Routes other `/api/*` endpoints to Gateway API (8001) - Routes non-API requests to Frontend (3000) -- Centralized CORS handling +- Same-origin API routing; split-origin or port-forwarded browser clients should use the Gateway `GATEWAY_CORS_ORIGINS` allowlist - SSE/streaming support for real-time agent responses - Optimized timeouts for long-running operations @@ -235,8 +230,8 @@ deer-flow/ │ └── nginx.local.conf # Nginx config for local dev ├── backend/ # Backend application │ ├── src/ -│ │ ├── gateway/ # Gateway API (port 8001) -│ │ ├── agents/ # LangGraph agents (port 2024) +│ │ ├── gateway/ # Gateway API and LangGraph-compatible runtime (port 8001) +│ │ ├── agents/ # LangGraph agent definitions │ │ ├── mcp/ # Model Context Protocol integration │ │ ├── skills/ # Skills system │ │ └── sandbox/ # Sandbox execution @@ -256,8 +251,7 @@ Browser ↓ Nginx (port 2026) ← Unified entry point ├→ Frontend (port 3000) ← / (non-API requests) - ├→ Gateway API (port 8001) ← /api/models, /api/mcp, /api/skills, /api/threads/*/artifacts - └→ LangGraph Server (port 2024) ← /api/langgraph/* (agent interactions) + └→ Gateway API (port 8001) ← /api/* and /api/langgraph/* (LangGraph-compatible agent interactions) ``` ## Development Workflow diff --git a/README.md b/README.md index 0fc8f173e..9ff1d501b 100644 --- a/README.md +++ b/README.md @@ -245,6 +245,8 @@ make down # Stop and remove containers Access: http://localhost:2026 +The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks. + See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide. #### Option 2: Local Development diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 99922a61e..67ee9cc7e 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -207,6 +207,8 @@ Configuration priority: FastAPI application on port 8001 with health check at `GET /health`. Set `GATEWAY_ENABLE_DOCS=false` to disable `/docs`, `/redoc`, and `/openapi.json` in production (default: enabled). +CORS is same-origin by default when requests enter through nginx on port 2026. Split-origin or port-forwarded browser clients must opt in with `GATEWAY_CORS_ORIGINS` (comma-separated exact origins); Gateway `CORSMiddleware` and `CSRFMiddleware` both read that variable so browser CORS and auth-origin checks stay aligned. + **Routers**: | Router | Endpoints | @@ -223,7 +225,7 @@ FastAPI application on port 8001 with health check at `GET /health`. Set `GATEWA | **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific | | **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id | -Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway. +Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs. ### Sandbox System (`packages/harness/deerflow/sandbox/`) diff --git a/backend/README.md b/backend/README.md index 6295eba22..9b4d26fb1 100644 --- a/backend/README.md +++ b/backend/README.md @@ -14,28 +14,31 @@ DeerFlow is a LangGraph-based AI super agent with sandbox execution, persistent │ │ /api/langgraph/* │ │ /api/* (other) ▼ ▼ - ┌────────────────────┐ ┌────────────────────────┐ - │ LangGraph Server │ │ Gateway API (8001) │ - │ (Port 2024) │ │ FastAPI REST │ - │ │ │ │ - │ ┌────────────────┐ │ │ Models, MCP, Skills, │ - │ │ Lead Agent │ │ │ Memory, Uploads, │ - │ │ ┌──────────┐ │ │ │ Artifacts │ - │ │ │Middleware│ │ │ └────────────────────────┘ - │ │ │ Chain │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │ Tools │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │Subagents │ │ │ - │ │ └──────────┘ │ │ - │ └────────────────┘ │ - └────────────────────┘ + ┌──────────────────────────────────────────────┐ + │ Gateway API (8001) │ + │ FastAPI REST + LangGraph-compatible runtime │ + │ │ + │ Models, MCP, Skills, Memory, Uploads, │ + │ Artifacts, Threads, Runs, Streaming │ + │ │ + │ ┌────────────────┐ │ + │ │ Lead Agent │ │ + │ │ ┌──────────┐ │ │ + │ │ │Middleware│ │ │ + │ │ │ Chain │ │ │ + │ │ └──────────┘ │ │ + │ │ ┌──────────┐ │ │ + │ │ │ Tools │ │ │ + │ │ └──────────┘ │ │ + │ │ ┌──────────┐ │ │ + │ │ │Subagents │ │ │ + │ │ └──────────┘ │ │ + │ └────────────────┘ │ + └──────────────────────────────────────────────┘ ``` **Request Routing** (via Nginx): -- `/api/langgraph/*` → LangGraph Server - agent interactions, threads, streaming +- `/api/langgraph/*` → Gateway API - LangGraph-compatible agent interactions, threads, runs, and streaming translated to native `/api/*` routers - `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup - `/` (non-API) → Frontend - Next.js web interface diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 2a506df2b..8848f473e 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -1,6 +1,5 @@ import asyncio import logging -import os from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -9,7 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware from app.gateway.auth_middleware import AuthMiddleware from app.gateway.config import get_gateway_config -from app.gateway.csrf_middleware import CSRFMiddleware +from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins from app.gateway.deps import langgraph_runtime from app.gateway.routers import ( agents, @@ -219,7 +218,9 @@ def create_app() -> FastAPI: Configured FastAPI application instance. """ config = get_gateway_config() - docs_kwargs = {"docs_url": "/docs", "redoc_url": "/redoc", "openapi_url": "/openapi.json"} if config.enable_docs else {"docs_url": None, "redoc_url": None, "openapi_url": None} + docs_url = "/docs" if config.enable_docs else None + redoc_url = "/redoc" if config.enable_docs else None + openapi_url = "/openapi.json" if config.enable_docs else None app = FastAPI( title="DeerFlow API Gateway", @@ -239,12 +240,14 @@ API Gateway for DeerFlow - A LangGraph-based AI agent backend with sandbox execu ### Architecture -LangGraph requests are handled by nginx reverse proxy. -This gateway provides custom endpoints for models, MCP configuration, skills, and artifacts. +LangGraph-compatible requests are routed through nginx to this gateway. +This gateway provides runtime endpoints for agent runs plus custom endpoints for models, MCP configuration, skills, and artifacts. """, version="0.1.0", lifespan=lifespan, - **docs_kwargs, + docs_url=docs_url, + redoc_url=redoc_url, + openapi_url=openapi_url, openapi_tags=[ { "name": "models", @@ -307,25 +310,18 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an # CSRF: Double Submit Cookie pattern for state-changing requests app.add_middleware(CSRFMiddleware) - # CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware. - # In production, nginx handles CORS and no middleware is needed. - cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "") - if cors_origins_env: - cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()] - # Validate: wildcard origin with credentials is a security misconfiguration - for origin in cors_origins: - if origin == "*": - logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.") - cors_origins = [o for o in cors_origins if o != "*"] - break - if cors_origins: - app.add_middleware( - CORSMiddleware, - allow_origins=cors_origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) + # CORS: the unified nginx endpoint is same-origin by default. Split-origin + # browser clients must opt in with this explicit Gateway allowlist so CORS + # and CSRF origin checks share the same source of truth. + cors_origins = sorted(get_configured_cors_origins()) + if cors_origins: + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) # Include routers # Models API is mounted at /api/models @@ -374,7 +370,7 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an app.include_router(runs.router) @app.get("/health", tags=["health"]) - async def health_check() -> dict: + async def health_check() -> dict[str, str]: """Health check endpoint. Returns: diff --git a/backend/app/gateway/config.py b/backend/app/gateway/config.py index 95221dad2..06a7d5b1a 100644 --- a/backend/app/gateway/config.py +++ b/backend/app/gateway/config.py @@ -8,7 +8,6 @@ class GatewayConfig(BaseModel): host: str = Field(default="0.0.0.0", description="Host to bind the gateway server") port: int = Field(default=8001, description="Port to bind the gateway server") - cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins") enable_docs: bool = Field(default=True, description="Enable Swagger/ReDoc/OpenAPI endpoints") @@ -19,11 +18,9 @@ def get_gateway_config() -> GatewayConfig: """Get gateway config, loading from environment if available.""" global _gateway_config if _gateway_config is None: - cors_origins_str = os.getenv("CORS_ORIGINS", "http://localhost:3000") _gateway_config = GatewayConfig( host=os.getenv("GATEWAY_HOST", "0.0.0.0"), port=int(os.getenv("GATEWAY_PORT", "8001")), - cors_origins=cors_origins_str.split(","), enable_docs=os.getenv("GATEWAY_ENABLE_DOCS", "true").lower() == "true", ) return _gateway_config diff --git a/backend/app/gateway/csrf_middleware.py b/backend/app/gateway/csrf_middleware.py index 08e95be4b..f34882032 100644 --- a/backend/app/gateway/csrf_middleware.py +++ b/backend/app/gateway/csrf_middleware.py @@ -6,7 +6,7 @@ State-changing operations require CSRF protection. import os import secrets -from collections.abc import Callable +from collections.abc import Awaitable, Callable from urllib.parse import urlsplit from fastapi import Request, Response @@ -106,6 +106,11 @@ def _configured_cors_origins() -> set[str]: return origins +def get_configured_cors_origins() -> set[str]: + """Return normalized explicit browser origins from GATEWAY_CORS_ORIGINS.""" + return _configured_cors_origins() + + def _first_header_value(value: str | None) -> str | None: """Return the first value from a comma-separated proxy header.""" if not value: @@ -172,7 +177,7 @@ class CSRFMiddleware(BaseHTTPMiddleware): def __init__(self, app: ASGIApp) -> None: super().__init__(app) - async def dispatch(self, request: Request, call_next: Callable) -> Response: + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: _is_auth = is_auth_endpoint(request) if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request): diff --git a/backend/docs/API.md b/backend/docs/API.md index dcefe6779..293c1ebd1 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -6,16 +6,16 @@ This document provides a complete reference for the DeerFlow backend APIs. DeerFlow backend exposes two sets of APIs: -1. **LangGraph API** - Agent interactions, threads, and streaming (`/api/langgraph/*`) +1. **LangGraph-compatible API** - Agent interactions, threads, and streaming (`/api/langgraph/*`) 2. **Gateway API** - Models, MCP, skills, uploads, and artifacts (`/api/*`) All APIs are accessed through the Nginx reverse proxy at port 2026. -## LangGraph API +## LangGraph-compatible API Base URL: `/api/langgraph` -The LangGraph API is provided by the LangGraph server and follows the LangGraph SDK conventions. +The public LangGraph-compatible API follows LangGraph SDK conventions. In the unified nginx deployment, Gateway owns `/api/langgraph/*` and translates those paths to its native `/api/*` run, thread, and streaming routers. ### Threads @@ -104,17 +104,11 @@ Content-Type: application/json **Recursion Limit:** `config.recursion_limit` caps the number of graph steps LangGraph will execute -in a single run. The `/api/langgraph/*` endpoints go straight to the LangGraph -server and therefore inherit LangGraph's native default of **25**, which is -too low for plan-mode or subagent-heavy runs — the agent typically errors out -with `GraphRecursionError` after the first round of subagent results comes -back, before the lead agent can synthesize the final answer. - -DeerFlow's own Gateway and IM-channel paths mitigate this by defaulting to -`100` in `build_run_config` (see `backend/app/gateway/services.py`), but -clients calling the LangGraph API directly must set `recursion_limit` -explicitly in the request body. `100` matches the Gateway default and is a -safe starting point; increase it if you run deeply nested subagent graphs. +in a single run. The unified Gateway path defaults to `100` in +`build_run_config` (see `backend/app/gateway/services.py`), which is a safer +starting point for plan-mode or subagent-heavy runs. Clients can still set +`recursion_limit` explicitly in the request body; increase it if you run deeply +nested subagent graphs. **Configurable Options:** - `model_name` (string): Override the default model @@ -649,7 +643,7 @@ curl -X POST http://localhost:2026/api/langgraph/threads/abc123/runs \ }' ``` -> The `/api/langgraph/*` endpoints bypass DeerFlow's Gateway and inherit -> LangGraph's native `recursion_limit` default of 25, which is too low for -> plan-mode or subagent runs. Set `config.recursion_limit` explicitly — see -> the [Create Run](#create-run) section for details. +> The unified Gateway path defaults `config.recursion_limit` to 100 for +> plan-mode and subagent-heavy runs. Clients may still set +> `config.recursion_limit` explicitly — see the [Create Run](#create-run) +> section for details. diff --git a/backend/docs/ARCHITECTURE.md b/backend/docs/ARCHITECTURE.md index cc0993f7f..e6fdbe217 100644 --- a/backend/docs/ARCHITECTURE.md +++ b/backend/docs/ARCHITECTURE.md @@ -14,8 +14,8 @@ This document provides a comprehensive overview of the DeerFlow backend architec │ Nginx (Port 2026) │ │ Unified Reverse Proxy Entry Point │ │ ┌────────────────────────────────────────────────────────────────────┐ │ -│ │ /api/langgraph/* → LangGraph Server (2024) │ │ -│ │ /api/* → Gateway API (8001) │ │ +│ │ /api/langgraph/* → Gateway LangGraph-compatible runtime (8001) │ │ +│ │ /api/* → Gateway REST APIs (8001) │ │ │ │ /* → Frontend (3000) │ │ │ └────────────────────────────────────────────────────────────────────┘ │ └─────────────────────────────────┬────────────────────────────────────────┘ @@ -24,8 +24,8 @@ This document provides a comprehensive overview of the DeerFlow backend architec │ │ │ ▼ ▼ ▼ ┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ -│ LangGraph Server │ │ Gateway API │ │ Frontend │ -│ (Port 2024) │ │ (Port 8001) │ │ (Port 3000) │ +│ Embedded Runtime │ │ Gateway API │ │ Frontend │ +│ (inside Gateway) │ │ (Port 8001) │ │ (Port 3000) │ │ │ │ │ │ │ │ - Agent Runtime │ │ - Models API │ │ - Next.js App │ │ - Thread Mgmt │ │ - MCP Config │ │ - React UI │ @@ -52,9 +52,9 @@ This document provides a comprehensive overview of the DeerFlow backend architec ## Component Details -### LangGraph Server +### Embedded LangGraph Runtime -The LangGraph server is the core agent runtime, built on LangGraph for robust multi-agent workflow orchestration. +The LangGraph-compatible runtime runs inside the Gateway process and is built on LangGraph for robust multi-agent workflow orchestration. **Entry Point**: `packages/harness/deerflow/agents/lead_agent/agent.py:make_lead_agent` @@ -78,7 +78,7 @@ The LangGraph server is the core agent runtime, built on LangGraph for robust mu ### Gateway API -FastAPI application providing REST endpoints for non-agent operations. +FastAPI application providing REST endpoints plus the public LangGraph-compatible `/api/langgraph/*` runtime routes. **Entry Point**: `app/gateway/app.py` @@ -353,10 +353,10 @@ SKILL.md Format: POST /api/langgraph/threads/{thread_id}/runs {"input": {"messages": [{"role": "user", "content": "Hello"}]}} -2. Nginx → LangGraph Server (2024) - Proxied to LangGraph server +2. Nginx → Gateway API (8001) + Routes `/api/langgraph/*` to the Gateway's LangGraph-compatible runtime -3. LangGraph Server +3. Embedded LangGraph runtime a. Load/create thread state b. Execute middleware chain: - ThreadDataMiddleware: Set up paths diff --git a/backend/tests/test_gateway_docs_toggle.py b/backend/tests/test_gateway_docs_toggle.py index 54392ee2e..372f93e18 100644 --- a/backend/tests/test_gateway_docs_toggle.py +++ b/backend/tests/test_gateway_docs_toggle.py @@ -122,3 +122,45 @@ def test_health_still_works_when_docs_disabled(): resp = client.get("/health") assert resp.status_code == 200 assert resp.json()["status"] == "healthy" + + +# --------------------------------------------------------------------------- +# Runtime CORS behavior +# --------------------------------------------------------------------------- + + +def _make_gateway_client(cors_origins: str) -> TestClient: + with patch.dict(os.environ, {"GATEWAY_CORS_ORIGINS": cors_origins}): + _reset_gateway_config() + from app.gateway.app import create_app + + return TestClient(create_app()) + + +def test_gateway_cors_allows_configured_origin(): + """GATEWAY_CORS_ORIGINS should control actual browser CORS responses.""" + client = _make_gateway_client("https://app.example") + + response = client.get("/health", headers={"Origin": "https://app.example"}) + + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://app.example" + assert response.headers["access-control-allow-credentials"] == "true" + + +def test_gateway_cors_rejects_unconfigured_origin(): + client = _make_gateway_client("https://app.example") + + response = client.get("/health", headers={"Origin": "https://evil.example"}) + + assert response.status_code == 200 + assert "access-control-allow-origin" not in response.headers + + +def test_gateway_cors_normalizes_configured_default_port(): + client = _make_gateway_client("https://app.example:443") + + response = client.get("/health", headers={"Origin": "https://app.example"}) + + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://app.example" diff --git a/backend/tests/test_gateway_runtime_cleanup.py b/backend/tests/test_gateway_runtime_cleanup.py index 3bf7c1a5b..895e04885 100644 --- a/backend/tests/test_gateway_runtime_cleanup.py +++ b/backend/tests/test_gateway_runtime_cleanup.py @@ -53,6 +53,29 @@ def test_nginx_routes_official_langgraph_prefix_to_gateway_api(): assert "proxy_pass http://gateway" in content or "proxy_pass http://$gateway_upstream" in content +def test_nginx_defers_cors_to_gateway_allowlist(): + for path in ("docker/nginx/nginx.local.conf", "docker/nginx/nginx.conf"): + content = _read(path) + + assert "Access-Control-Allow-Origin" not in content + assert "Access-Control-Allow-Methods" not in content + assert "Access-Control-Allow-Headers" not in content + assert "Access-Control-Allow-Credentials" not in content + assert "proxy_hide_header 'Access-Control-Allow-" not in content + assert "if ($request_method = 'OPTIONS')" not in content + + +def test_gateway_cors_configuration_uses_gateway_allowlist(): + gateway_config = _read("backend/app/gateway/config.py") + gateway_app = _read("backend/app/gateway/app.py") + csrf_middleware = _read("backend/app/gateway/csrf_middleware.py") + + assert not re.search(r"(? Date: Mon, 11 May 2026 13:54:00 +0200 Subject: [PATCH 07/86] fix(runtime): persist run message summaries (#2850) * fix(runtime): persist run message summaries (#2849) * fix(runtime): dedupe run message summaries --- .../harness/deerflow/runtime/journal.py | 56 ++++++++++- backend/tests/test_run_journal.py | 93 +++++++++++++++++++ 2 files changed, 148 insertions(+), 1 deletion(-) diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index 41e48efed..8a9382e23 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -20,12 +20,13 @@ from __future__ import annotations import asyncio import logging import time +from collections.abc import Mapping from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, cast from uuid import UUID from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage +from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage from langgraph.types import Command if TYPE_CHECKING: @@ -71,6 +72,7 @@ class RunJournal(BaseCallbackHandler): # Dedup: LangChain may fire on_llm_end multiple times for the same run_id self._counted_llm_run_ids: set[str] = set() self._counted_external_source_ids: set[str] = set() + self._counted_message_llm_run_ids: set[str] = set() # Convenience fields self._last_ai_msg: str | None = None @@ -86,6 +88,50 @@ class RunJournal(BaseCallbackHandler): # -- Lifecycle callbacks -- + @staticmethod + def _message_text(message: BaseMessage) -> str: + """Extract displayable text from a message's mixed content shape.""" + content = getattr(message, "content", None) + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, Mapping): + text = block.get("text") + if isinstance(text, str): + parts.append(text) + else: + nested = block.get("content") + if isinstance(nested, str): + parts.append(nested) + return "".join(parts) + if isinstance(content, Mapping): + for key in ("text", "content"): + value = content.get(key) + if isinstance(value, str): + return value + + text = getattr(message, "text", None) + if isinstance(text, str): + return text + return "" + + def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None: + """Update run-level convenience fields for persisted run rows.""" + self._msg_count += 1 + + # ``last_ai_message`` should represent the lead agent's user-facing + # answer. Middleware/subagent model calls and empty tool-call-only + # AI messages must not overwrite the last useful assistant text. + is_ai_message = isinstance(message, AIMessage) or getattr(message, "type", None) == "ai" + if is_ai_message and (caller is None or caller == "lead_agent"): + text = self._message_text(message).strip() + if text: + self._last_ai_msg = text[:2000] + def on_chain_start( self, serialized: dict[str, Any], @@ -164,6 +210,7 @@ class RunJournal(BaseCallbackHandler): content=m.model_dump(), metadata={"caller": caller}, ) + self._record_message_summary(m, caller=caller) break if self._first_human_msg: break @@ -222,6 +269,8 @@ class RunJournal(BaseCallbackHandler): "llm_call_index": call_index, }, ) + if rid not in self._counted_message_llm_run_ids: + self._record_message_summary(message, caller=caller) # Token accumulation (dedup by langchain run_id to avoid double-counting # when the callback fires more than once for the same response) @@ -245,6 +294,9 @@ class RunJournal(BaseCallbackHandler): else: self._lead_agent_tokens += total_tk + if messages: + self._counted_message_llm_run_ids.add(str(run_id)) + def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: self._llm_start_times.pop(str(run_id), None) self._put(event_type="llm.error", category="trace", content=str(error)) @@ -260,12 +312,14 @@ class RunJournal(BaseCallbackHandler): if isinstance(output, ToolMessage): msg = cast(ToolMessage, output) self._put(event_type="llm.tool.result", category="message", content=msg.model_dump()) + self._record_message_summary(msg) elif isinstance(output, Command): cmd = cast(Command, output) messages = cmd.update.get("messages", []) for message in messages: if isinstance(message, BaseMessage): self._put(event_type="llm.tool.result", category="message", content=message.model_dump()) + self._record_message_summary(message) else: logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}") else: diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index 27c05619c..8615caa49 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -339,6 +339,99 @@ class TestConvenienceFields: data = j.get_completion_data() assert data["first_human_message"] == "What is AI?" + @pytest.mark.anyio + async def test_completion_data_counts_human_ai_and_tool_messages(self, journal_setup): + from langchain_core.messages import HumanMessage, ToolMessage + + j, _ = journal_setup + j.on_chat_model_start({}, [[HumanMessage(content="Question")]], run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.on_tool_end(ToolMessage(content="Tool result", tool_call_id="call_1", name="search"), run_id=uuid4()) + + data = j.get_completion_data() + + assert data["message_count"] == 3 + assert data["first_human_message"] == "Question" + assert data["last_ai_message"] == "Answer" + + @pytest.mark.anyio + async def test_tool_call_only_ai_does_not_clear_last_ai_message(self, journal_setup): + j, _ = journal_setup + j.on_llm_end(_make_llm_response("Useful answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end( + _make_llm_response("", tool_calls=[{"id": "call_1", "name": "search", "args": {}}]), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + + data = j.get_completion_data() + + assert data["message_count"] == 2 + assert data["last_ai_message"] == "Useful answer" + + @pytest.mark.anyio + async def test_last_ai_message_extracts_mixed_content_without_extra_newlines(self, journal_setup): + j, _ = journal_setup + j.on_llm_end( + _make_llm_response( + [ + {"type": "text", "text": "First "}, + {"type": "text", "content": "second"}, + " third", + {"type": "image", "url": "ignored"}, + ] + ), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + + data = j.get_completion_data() + + assert data["message_count"] == 1 + assert data["last_ai_message"] == "First second third" + + @pytest.mark.anyio + async def test_last_ai_message_extracts_mapping_content(self, journal_setup): + j, _ = journal_setup + j.on_llm_end(_make_llm_response({"content": "Nested answer"}), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + + data = j.get_completion_data() + + assert data["message_count"] == 1 + assert data["last_ai_message"] == "Nested answer" + + @pytest.mark.anyio + async def test_duplicate_llm_run_id_does_not_double_count_message_summary(self, journal_setup): + j, _ = journal_setup + run_id = uuid4() + + j.on_llm_end(_make_llm_response("Answer", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end( + _make_llm_response("Answer", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + run_id=run_id, + parent_run_id=None, + tags=["lead_agent"], + ) + + data = j.get_completion_data() + + assert data["message_count"] == 1 + assert data["last_ai_message"] == "Answer" + assert data["total_tokens"] == 15 + + @pytest.mark.anyio + async def test_subagent_ai_does_not_overwrite_lead_last_ai_message(self, journal_setup): + j, _ = journal_setup + j.on_llm_end(_make_llm_response("Lead answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Subagent detail"), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"]) + + data = j.get_completion_data() + + assert data["message_count"] == 2 + assert data["last_ai_message"] == "Lead answer" + @pytest.mark.anyio async def test_get_completion_data(self, journal_setup): j, _ = journal_setup From de253e4a0a9e4bcfa5fb3ce20e280fc8737ec5fc Mon Sep 17 00:00:00 2001 From: Yi Tang <6054101+yitang@users.noreply.github.com> Date: Mon, 11 May 2026 14:45:18 +0100 Subject: [PATCH 08/86] feat(run): Propagates `model_name` from the gateway request through the runtime and persistence stack to the SQLite database. (#2775) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(run): propagate model_name from gateway request context to persistence layer Pass model_name through the full run creation pipeline — from RunCreateRequest.context in the gateway, through RunManager, to the RunStore interface and SQL persistence. This enables client-specified model selection to be recorded per-run in the database. * feat(run): add model allowlist validation and effective model name capture - Validate model_name against allowlist in gateway services.py using get_app_config().get_model_config() - Truncate model_name to 128 chars to match DB column constraint - In worker.py, capture effective model name from agent.metadata after agent creation and persist if resolved differently than requested * feat(run): add defense-in-depth model_name normalization and round-trip persistence tests - Add _normalize_model_name() to RunRepository for whitespace stripping and 128-char truncation before DB writes. - Add round-trip unit tests for model_name creation and default None in test_run_manager.py. * fix(run): coerce non-string model_name values before strip/truncate in _normalize_model_name * fix(gateway): add runtime type guard for model_name coercion in gateway services Add isinstance check and str() coercion before calling .strip() to prevent AttributeError when non-string types (int, None, etc.) flow through the gateway. Paired with SQL integration test for end-to-end model_name persistence across gateway → langgraph → persistence layer. * fix(run): drop Alembic migration for model_name (no-op) and expose public update method on RunManager - Drop a1b2c3d4e5f6 migration: model_name already exists in RunRow schema and is auto-created via Base.metadata.create_all() at startup - Add update_model_name() public method to RunManager to replace the private _persist_to_store call in worker.py, preserving internal locking/persistence --- backend/app/gateway/services.py | 19 +++++++ .../harness/deerflow/persistence/run/sql.py | 14 +++++ .../harness/deerflow/runtime/runs/manager.py | 16 ++++++ .../deerflow/runtime/runs/store/base.py | 1 + .../deerflow/runtime/runs/store/memory.py | 2 + .../harness/deerflow/runtime/runs/worker.py | 11 ++++ backend/tests/test_run_manager.py | 51 +++++++++++++++++++ backend/tests/test_run_repository.py | 29 +++++++++++ 8 files changed, 143 insertions(+) diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 0cbea4faf..96521b86f 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -19,6 +19,7 @@ from langchain_core.messages import HumanMessage from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge from app.gateway.utils import sanitize_log_param +from deerflow.config.app_config import get_app_config from deerflow.runtime import ( END_SENTINEL, HEARTBEAT_SENTINEL, @@ -267,6 +268,23 @@ async def start_run( disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ + body_context = getattr(body, "context", None) or {} + model_name = body_context.get("model_name") + + # Coerce non-string model_name values to str before truncation. + if model_name is not None and not isinstance(model_name, str): + model_name = str(model_name) + + # Validate model against the allowlist when a model_name is provided. + if model_name: + app_config = get_app_config() + resolved = app_config.get_model_config(model_name) + if resolved is None: + raise HTTPException( + status_code=400, + detail=f"Model {model_name!r} is not in the configured model allowlist", + ) + try: record = await run_mgr.create_or_reject( thread_id, @@ -275,6 +293,7 @@ async def start_run( metadata=body.metadata or {}, kwargs={"input": body.input, "config": body.config}, multitask_strategy=body.multitask_strategy, + model_name=model_name, ) except ConflictError as exc: raise HTTPException(status_code=409, detail=str(exc)) from exc diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index fcd1a3411..430fbe4f6 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -23,6 +23,18 @@ class RunRepository(RunStore): def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: self._sf = session_factory + @staticmethod + def _normalize_model_name(model_name: str | None) -> str | None: + """Normalize model_name for storage: strip whitespace, truncate to 128 chars.""" + if model_name is None: + return None + if not isinstance(model_name, str): + model_name = str(model_name) + normalized = model_name.strip() + if len(normalized) > 128: + normalized = normalized[:128] + return normalized + @staticmethod def _safe_json(obj: Any) -> Any: """Ensure obj is JSON-serializable. Falls back to model_dump() or str().""" @@ -70,6 +82,7 @@ class RunRepository(RunStore): thread_id, assistant_id=None, user_id: str | None | _AutoSentinel = AUTO, + model_name: str | None = None, status="pending", multitask_strategy="reject", metadata=None, @@ -85,6 +98,7 @@ class RunRepository(RunStore): thread_id=thread_id, assistant_id=assistant_id, user_id=resolved_user_id, + model_name=self._normalize_model_name(model_name), status=status, multitask_strategy=multitask_strategy, metadata_json=self._safe_json(metadata) or {}, diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 533342c87..50dc594ab 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -36,6 +36,7 @@ class RunRecord: abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False) abort_action: str = "interrupt" error: str | None = None + model_name: str | None = None class RunManager: @@ -65,6 +66,7 @@ class RunManager: metadata=record.metadata or {}, kwargs=record.kwargs or {}, created_at=record.created_at, + model_name=record.model_name, ) except Exception: logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) @@ -137,6 +139,18 @@ class RunManager: logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) logger.info("Run %s -> %s", run_id, status.value) + async def update_model_name(self, run_id: str, model_name: str | None) -> None: + """Update the model name for a run.""" + async with self._lock: + record = self._runs.get(run_id) + if record is None: + logger.warning("update_model_name called for unknown run %s", run_id) + return + record.model_name = model_name + record.updated_at = _now_iso() + await self._persist_to_store(record) + logger.info("Run %s model_name=%s", run_id, model_name) + async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: """Request cancellation of a run. @@ -171,6 +185,7 @@ class RunManager: metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", + model_name: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. @@ -221,6 +236,7 @@ class RunManager: kwargs=kwargs or {}, created_at=now, updated_at=now, + model_name=model_name, ) self._runs[run_id] = record diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 518a1903c..d3c10eba6 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -23,6 +23,7 @@ class RunStore(abc.ABC): thread_id: str, assistant_id: str | None = None, user_id: str | None = None, + model_name: str | None = None, status: str = "pending", multitask_strategy: str = "reject", metadata: dict[str, Any] | None = None, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 5a14af3df..e41147e3e 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -22,6 +22,7 @@ class MemoryRunStore(RunStore): thread_id, assistant_id=None, user_id=None, + model_name=None, status="pending", multitask_strategy="reject", metadata=None, @@ -35,6 +36,7 @@ class MemoryRunStore(RunStore): "thread_id": thread_id, "assistant_id": assistant_id, "user_id": user_id, + "model_name": model_name, "status": status, "multitask_strategy": multitask_strategy, "metadata": metadata or {}, diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 2aecb9a1b..f78d425a2 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -230,6 +230,17 @@ async def run_agent( else: agent = agent_factory(config=runnable_config) + # Capture the effective (resolved) model name from the agent's metadata. + # _resolve_model_name in agent.py may return the default model if the + # requested name is not in the allowlist — this update ensures the + # persisted model_name reflects the actual model used. + if record.model_name is not None: + resolved = getattr(agent, "metadata", {}) or {} + if isinstance(resolved, dict): + effective = resolved.get("model_name") + if effective and effective != record.model_name: + await run_manager.update_model_name(record.run_id, effective) + # 4. Attach checkpointer and store if checkpointer is not None: agent.checkpointer = checkpointer diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index 58ecf1f26..98cd58264 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -5,6 +5,7 @@ import re import pytest from deerflow.runtime import RunManager, RunStatus +from deerflow.runtime.runs.store.memory import MemoryRunStore ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -141,3 +142,53 @@ async def test_create_defaults(manager: RunManager): assert record.kwargs == {} assert record.multitask_strategy == "reject" assert record.assistant_id is None + + +@pytest.mark.anyio +async def test_model_name_create_or_reject(): + """create_or_reject should accept and persist model_name.""" + from deerflow.runtime.runs.schemas import DisconnectMode + + store = MemoryRunStore() + mgr = RunManager(store=store) + + record = await mgr.create_or_reject( + "thread-1", + assistant_id="lead_agent", + on_disconnect=DisconnectMode.cancel, + metadata={"key": "val"}, + kwargs={"input": {}}, + multitask_strategy="reject", + model_name="anthropic.claude-sonnet-4-20250514-v1:0", + ) + assert record.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" + assert record.status == RunStatus.pending + + # Verify model_name was persisted to store + stored = await store.get(record.run_id) + assert stored is not None + assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0" + + # Verify retrieval returns the model_name via in-memory record + fetched = mgr.get(record.run_id) + assert fetched is not None + assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" + + +@pytest.mark.anyio +async def test_model_name_default_is_none(): + """create_or_reject without model_name should default to None.""" + from deerflow.runtime.runs.schemas import DisconnectMode + + store = MemoryRunStore() + mgr = RunManager(store=store) + + record = await mgr.create_or_reject( + "thread-1", + on_disconnect=DisconnectMode.cancel, + model_name=None, + ) + assert record.model_name is None + + stored = await store.get(record.run_id) + assert stored["model_name"] is None diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index bff49206d..6fd534829 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -249,3 +249,32 @@ class TestRunRepository: rows = await repo.list_by_thread("t1", user_id=None) assert len(rows) == 2 await _cleanup() + + @pytest.mark.anyio + async def test_model_name_persistence(self, tmp_path): + """RunRepository should persist, normalize, and truncate model_name correctly via SQL.""" + from deerflow.persistence.engine import get_session_factory, init_engine + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + repo = RunRepository(get_session_factory()) + + await repo.put("run-1", thread_id="thread-1", model_name="gpt-4o") + row = await repo.get("run-1") + assert row is not None + assert row["model_name"] == "gpt-4o" + + long_name = "a" * 200 + await repo.put("run-2", thread_id="thread-1", model_name=long_name) + row2 = await repo.get("run-2") + assert row2["model_name"] == "a" * 128 + + await repo.put("run-3", thread_id="thread-1", model_name=123) + row3 = await repo.get("run-3") + assert row3["model_name"] == "123" + + await repo.put("run-4", thread_id="thread-1", model_name=None) + row4 = await repo.get("run-4") + assert row4["model_name"] is None + + await _cleanup() From bedbf2291e182a53c7be6bece9485d44300d1925 Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Mon, 11 May 2026 22:14:13 +0800 Subject: [PATCH 09/86] fix(harness): wrap async-only config tools for sync client execution (#2878) * fix(harness): wrap async-only config tools for sync clients * refactor(tools): share async tool sync wrapper --- .../packages/harness/deerflow/mcp/tools.py | 45 +------------------ .../deerflow/tools/skill_manage_tool.py | 4 +- .../packages/harness/deerflow/tools/sync.py | 36 +++++++++++++++ .../packages/harness/deerflow/tools/tools.py | 10 ++++- backend/tests/test_mcp_sync_wrapper.py | 16 +++---- backend/tests/test_tool_deduplication.py | 42 ++++++++++++++++- 6 files changed, 98 insertions(+), 55 deletions(-) create mode 100644 backend/packages/harness/deerflow/tools/sync.py diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index bcd50c645..d27641692 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -1,11 +1,6 @@ """Load MCP tools using langchain-mcp-adapters.""" -import asyncio -import atexit -import concurrent.futures import logging -from collections.abc import Callable -from typing import Any from langchain_core.tools import BaseTool @@ -13,46 +8,10 @@ from deerflow.config.extensions_config import ExtensionsConfig from deerflow.mcp.client import build_servers_config from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers from deerflow.reflection import resolve_variable +from deerflow.tools.sync import make_sync_tool_wrapper logger = logging.getLogger(__name__) -# Global thread pool for sync tool invocation in async environments -_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="mcp-sync-tool") - -# Register shutdown hook for the global executor -atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False)) - - -def _make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: - """Build a synchronous wrapper for an asynchronous tool coroutine. - - Args: - coro: The tool's asynchronous coroutine. - tool_name: Name of the tool (for logging). - - Returns: - A synchronous function that correctly handles nested event loops. - """ - - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - try: - if loop is not None and loop.is_running(): - # Use global executor to avoid nested loop issues and improve performance - future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs)) - return future.result() - else: - return asyncio.run(coro(*args, **kwargs)) - except Exception as e: - logger.error(f"Error invoking MCP tool '{tool_name}' via sync wrapper: {e}", exc_info=True) - raise - - return sync_wrapper - async def get_mcp_tools() -> list[BaseTool]: """Get all tools from enabled MCP servers. @@ -126,7 +85,7 @@ async def get_mcp_tools() -> list[BaseTool]: # Patch tools to support sync invocation, as deerflow client streams synchronously for tool in tools: if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None: - tool.func = _make_sync_tool_wrapper(tool.coroutine, tool.name) + tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name) return tools diff --git a/backend/packages/harness/deerflow/tools/skill_manage_tool.py b/backend/packages/harness/deerflow/tools/skill_manage_tool.py index 46865242c..2a39732bc 100644 --- a/backend/packages/harness/deerflow/tools/skill_manage_tool.py +++ b/backend/packages/harness/deerflow/tools/skill_manage_tool.py @@ -10,11 +10,11 @@ from weakref import WeakValueDictionary from langchain.tools import tool from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async -from deerflow.mcp.tools import _make_sync_tool_wrapper from deerflow.skills.security_scanner import scan_skill_content from deerflow.skills.storage import get_or_new_skill_storage from deerflow.skills.storage.skill_storage import SkillStorage from deerflow.skills.types import SKILL_MD_FILE +from deerflow.tools.sync import make_sync_tool_wrapper from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) @@ -235,4 +235,4 @@ async def skill_manage_tool( ) -skill_manage_tool.func = _make_sync_tool_wrapper(_skill_manage_impl, "skill_manage") +skill_manage_tool.func = make_sync_tool_wrapper(_skill_manage_impl, "skill_manage") diff --git a/backend/packages/harness/deerflow/tools/sync.py b/backend/packages/harness/deerflow/tools/sync.py new file mode 100644 index 000000000..c2b80781a --- /dev/null +++ b/backend/packages/harness/deerflow/tools/sync.py @@ -0,0 +1,36 @@ +"""Utilities for invoking async tools from synchronous agent paths.""" + +import asyncio +import atexit +import concurrent.futures +import logging +from collections.abc import Callable +from typing import Any + +logger = logging.getLogger(__name__) + +# Shared thread pool for sync tool invocation in async environments. +_SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thread_name_prefix="tool-sync") + +atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False)) + + +def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: + """Build a synchronous wrapper for an asynchronous tool coroutine.""" + + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + try: + if loop is not None and loop.is_running(): + future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs)) + return future.result() + return asyncio.run(coro(*args, **kwargs)) + except Exception as e: + logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True) + raise + + return sync_wrapper diff --git a/backend/packages/harness/deerflow/tools/tools.py b/backend/packages/harness/deerflow/tools/tools.py index 14d93e65f..01bfce43f 100644 --- a/backend/packages/harness/deerflow/tools/tools.py +++ b/backend/packages/harness/deerflow/tools/tools.py @@ -8,6 +8,7 @@ from deerflow.reflection import resolve_variable from deerflow.sandbox.security import is_host_bash_allowed from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool from deerflow.tools.builtins.tool_search import reset_deferred_registry +from deerflow.tools.sync import make_sync_tool_wrapper logger = logging.getLogger(__name__) @@ -33,6 +34,13 @@ def _is_host_bash_tool(tool: object) -> bool: return False +def _ensure_sync_invocable_tool(tool: BaseTool) -> BaseTool: + """Attach a sync wrapper to async-only tools used by sync agent callers.""" + if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None: + tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name) + return tool + + def get_available_tools( groups: list[str] | None = None, include_mcp: bool = True, @@ -77,7 +85,7 @@ def get_available_tools( cfg.use, ) - loaded_tools = [t for _, t in loaded_tools_raw] + loaded_tools = [_ensure_sync_invocable_tool(t) for _, t in loaded_tools_raw] # Conditionally add tools based on config builtin_tools = BUILTIN_TOOLS.copy() diff --git a/backend/tests/test_mcp_sync_wrapper.py b/backend/tests/test_mcp_sync_wrapper.py index 376d1a790..285200781 100644 --- a/backend/tests/test_mcp_sync_wrapper.py +++ b/backend/tests/test_mcp_sync_wrapper.py @@ -5,7 +5,8 @@ import pytest from langchain_core.tools import StructuredTool from pydantic import BaseModel, Field -from deerflow.mcp.tools import _make_sync_tool_wrapper, get_mcp_tools +from deerflow.mcp.tools import get_mcp_tools +from deerflow.tools.sync import make_sync_tool_wrapper class MockArgs(BaseModel): @@ -51,14 +52,13 @@ def test_mcp_tool_sync_wrapper_generation(): def test_mcp_tool_sync_wrapper_in_running_loop(): - """Test the actual helper function from production code (Fix for Comment 1 & 3).""" + """Test the shared sync wrapper from production code.""" async def mock_coro(x: int): await asyncio.sleep(0.01) return f"async_result: {x}" - # Test the real helper function exported from deerflow.mcp.tools - sync_func = _make_sync_tool_wrapper(mock_coro, "test_tool") + sync_func = make_sync_tool_wrapper(mock_coro, "test_tool") async def run_in_loop(): # This call should succeed due to ThreadPoolExecutor in the real helper @@ -70,16 +70,16 @@ def test_mcp_tool_sync_wrapper_in_running_loop(): def test_mcp_tool_sync_wrapper_exception_logging(): - """Test the actual helper's error logging (Fix for Comment 3).""" + """Test the shared sync wrapper's error logging.""" async def error_coro(): raise ValueError("Tool failure") - sync_func = _make_sync_tool_wrapper(error_coro, "error_tool") + sync_func = make_sync_tool_wrapper(error_coro, "error_tool") - with patch("deerflow.mcp.tools.logger.error") as mock_log_error: + with patch("deerflow.tools.sync.logger.error") as mock_log_error: with pytest.raises(ValueError, match="Tool failure"): sync_func() mock_log_error.assert_called_once() # Verify the tool name is in the log message - assert "error_tool" in mock_log_error.call_args[0][0] + assert mock_log_error.call_args[0][1] == "error_tool" diff --git a/backend/tests/test_tool_deduplication.py b/backend/tests/test_tool_deduplication.py index 35ec0bea6..ed9efffaf 100644 --- a/backend/tests/test_tool_deduplication.py +++ b/backend/tests/test_tool_deduplication.py @@ -10,7 +10,8 @@ from __future__ import annotations from unittest.mock import MagicMock, patch -from langchain_core.tools import BaseTool, tool +from langchain_core.tools import BaseTool, StructuredTool, tool +from pydantic import BaseModel, Field from deerflow.tools.tools import get_available_tools @@ -19,6 +20,10 @@ from deerflow.tools.tools import get_available_tools # --------------------------------------------------------------------------- +class AsyncToolArgs(BaseModel): + x: int = Field(..., description="test input") + + @tool def _tool_alpha(x: str) -> str: """Alpha tool.""" @@ -52,10 +57,45 @@ def _make_minimal_config(tools): config.tools = tools config.models = [] config.tool_search.enabled = False + config.skill_evolution.enabled = False config.sandbox = MagicMock() + config.acp_agents = {} return config +@patch("deerflow.tools.tools.get_app_config") +@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) +@patch("deerflow.tools.tools.reset_deferred_registry") +def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg): + """Config-loaded async-only tools can still be invoked by sync clients.""" + + async def async_tool_impl(x: int) -> str: + return f"result: {x}" + + async_tool = StructuredTool( + name="async_tool", + description="Async-only test tool.", + args_schema=AsyncToolArgs, + func=None, + coroutine=async_tool_impl, + ) + tool_cfg = MagicMock() + tool_cfg.name = "async_tool" + tool_cfg.group = "test" + tool_cfg.use = "tests.fake:async_tool" + mock_cfg.return_value = _make_minimal_config([tool_cfg]) + + with ( + patch("deerflow.tools.tools.resolve_variable", return_value=async_tool), + patch("deerflow.tools.tools.BUILTIN_TOOLS", []), + ): + result = get_available_tools(include_mcp=False, app_config=mock_cfg.return_value) + + assert async_tool in result + assert async_tool.func is not None + assert async_tool.invoke({"x": 42}) == "result: 42" + + @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) @patch("deerflow.tools.tools.reset_deferred_registry") From 1f978393ec6558b3c91f30475f28b805ad0bb803 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 May 2026 10:35:34 +0800 Subject: [PATCH 10/86] chore(deps): bump urllib3 from 2.6.3 to 2.7.0 in /backend (#2898) Bumps [urllib3](https://github.com/urllib3/urllib3) from 2.6.3 to 2.7.0. - [Release notes](https://github.com/urllib3/urllib3/releases) - [Changelog](https://github.com/urllib3/urllib3/blob/main/CHANGES.rst) - [Commits](https://github.com/urllib3/urllib3/compare/2.6.3...2.7.0) --- updated-dependencies: - dependency-name: urllib3 dependency-version: 2.7.0 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- backend/uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/uv.lock b/backend/uv.lock index 64cab46d9..e144fb07e 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -4224,11 +4224,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.6.3" +version = "2.7.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, + { url = "https://files.pythonhosted.org/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" }, ] [[package]] From 0009655454cda708654ebe41c6ede5cb9e3fc760 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 12 May 2026 10:45:40 +0800 Subject: [PATCH 11/86] chore(deps): bump next from 16.1.7 to 16.2.6 in /frontend (#2899) Bumps [next](https://github.com/vercel/next.js) from 16.1.7 to 16.2.6. - [Release notes](https://github.com/vercel/next.js/releases) - [Changelog](https://github.com/vercel/next.js/blob/canary/release.js) - [Commits](https://github.com/vercel/next.js/compare/v16.1.7...v16.2.6) --- updated-dependencies: - dependency-name: next dependency-version: 16.2.6 dependency-type: direct:production ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- frontend/package.json | 2 +- frontend/pnpm-lock.yaml | 154 ++++++++++++++++++++++------------------ 2 files changed, 85 insertions(+), 71 deletions(-) diff --git a/frontend/package.json b/frontend/package.json index 2ce4e2f6d..0a46ee452 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -68,7 +68,7 @@ "lucide-react": "^0.562.0", "motion": "^12.26.2", "nanoid": "^5.1.6", - "next": "^16.1.7", + "next": "^16.2.6", "next-themes": "^0.4.6", "nextra": "^4.6.1", "nextra-theme-docs": "^4.6.1", diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index d27c6687c..8c80061c9 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -156,17 +156,17 @@ importers: specifier: ^5.1.6 version: 5.1.6 next: - specifier: ^16.1.7 - version: 16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + specifier: ^16.2.6 + version: 16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) next-themes: specifier: ^0.4.6 version: 0.4.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4) nextra: specifier: ^4.6.1 - version: 4.6.1(next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3) + version: 4.6.1(next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3) nextra-theme-docs: specifier: ^4.6.1 - version: 4.6.1(@types/react@19.2.13)(next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(nextra@4.6.1(next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) + version: 4.6.1(@types/react@19.2.13)(next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(nextra@4.6.1(next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)) nuxt-og-image: specifier: ^5.1.13 version: 5.1.13(@unhead/vue@2.1.4(vue@3.5.28(typescript@5.9.3)))(unstorage@1.17.4)(vite@7.3.1(@types/node@20.19.33)(jiti@2.6.1)(lightningcss@1.30.2)(yaml@2.8.3))(vue@3.5.28(typescript@5.9.3)) @@ -437,8 +437,8 @@ packages: '@emnapi/core@1.8.1': resolution: {integrity: sha512-AvT9QFpxK0Zd8J0jopedNm+w/2fIzvtPKPjqyw9jwvBaReTTqPBk9Hixaz7KbjimP+QNz605/XnjFcDAL2pqBg==} - '@emnapi/runtime@1.9.0': - resolution: {integrity: sha512-QN75eB0IH2ywSpRpNddCRfQIhmJYBCJ1x5Lb3IscKAL8bMnVAKnRg8dCoXbHzVLLH7P38N2Z3mtulB7W0J0FKw==} + '@emnapi/runtime@1.10.0': + resolution: {integrity: sha512-ewvYlk86xUoGI0zQRNq/mC+16R1QeDlKQy21Ki3oSYXNgLb45GV1P6A0M+/s6nyCuNDqe5VpaY84BzXGwVbwFA==} '@emnapi/wasi-threads@1.1.0': resolution: {integrity: sha512-WI0DdZ8xFSbgMjR1sFsKABJ/C5OnRrjT06JXbZKexJGrDuPTzZdDYfFlsgcCXCyf+suG5QU2e/y1Wo2V/OapLQ==} @@ -1018,56 +1018,56 @@ packages: '@napi-rs/wasm-runtime@0.2.12': resolution: {integrity: sha512-ZVWUcfwY4E/yPitQJl481FjFo3K22D6qF0DuFH6Y/nbnE11GY5uguDxZMGXPQ8WQ0128MXQD7TnfHyK4oWoIJQ==} - '@next/env@16.1.7': - resolution: {integrity: sha512-rJJbIdJB/RQr2F1nylZr/PJzamvNNhfr3brdKP6s/GW850jbtR70QlSfFselvIBbcPUOlQwBakexjFzqLzF6pg==} + '@next/env@16.2.6': + resolution: {integrity: sha512-gd8HoHN4ufj73WmR3JmVolrpJR47ILK6LouP5xElPglaVxir6e1a7VzvTvDWkOoPXT9rkkTzyCxBu4yeZfZwcw==} '@next/eslint-plugin-next@15.5.12': resolution: {integrity: sha512-+ZRSDFTv4aC96aMb5E41rMjysx8ApkryevnvEYZvPZO52KvkqP5rNExLUXJFr9P4s0f3oqNQR6vopCZsPWKDcQ==} - '@next/swc-darwin-arm64@16.1.7': - resolution: {integrity: sha512-b2wWIE8sABdyafc4IM8r5Y/dS6kD80JRtOGrUiKTsACFQfWWgUQ2NwoUX1yjFMXVsAwcQeNpnucF2ZrujsBBPg==} + '@next/swc-darwin-arm64@16.2.6': + resolution: {integrity: sha512-ZJGkkcNfYgrrMkqOdZ7zoLa1TOy0qpcMfk/z4Mh/FKUz40gVO+HNQWqmLxf67Z5WB64DRp0dhEbyHfel+6sJUg==} engines: {node: '>= 10'} cpu: [arm64] os: [darwin] - '@next/swc-darwin-x64@16.1.7': - resolution: {integrity: sha512-zcnVaaZulS1WL0Ss38R5Q6D2gz7MtBu8GZLPfK+73D/hp4GFMrC2sudLky1QibfV7h6RJBJs/gOFvYP0X7UVlQ==} + '@next/swc-darwin-x64@16.2.6': + resolution: {integrity: sha512-v/YLBHIY132Ced3puBJ7YJKw1lqsCrgcNo2aRJlCEyQrrCeRJlvGlnmxhPxNQI3KE3N1DN5r9TPNPvka3nq5RQ==} engines: {node: '>= 10'} cpu: [x64] os: [darwin] - '@next/swc-linux-arm64-gnu@16.1.7': - resolution: {integrity: sha512-2ant89Lux/Q3VyC8vNVg7uBaFVP9SwoK2jJOOR0L8TQnX8CAYnh4uctAScy2Hwj2dgjVHqHLORQZJ2wH6VxhSQ==} + '@next/swc-linux-arm64-gnu@16.2.6': + resolution: {integrity: sha512-RPOvqlYBbcQjkz9VQQDZ2T2bARIjXZV1KFlt+V2Mr6SW/e4I9fcKsaA0hdyf2FHoTlsV2xnBd5Y912rP/1Ce6w==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] - '@next/swc-linux-arm64-musl@16.1.7': - resolution: {integrity: sha512-uufcze7LYv0FQg9GnNeZ3/whYfo+1Q3HnQpm16o6Uyi0OVzLlk2ZWoY7j07KADZFY8qwDbsmFnMQP3p3+Ftprw==} + '@next/swc-linux-arm64-musl@16.2.6': + resolution: {integrity: sha512-URUTu1+dMkxJsPFgm+OeEvq9wf5sujw0EvgYy80TDGHTSLTnIHeqb0Eu8A3sC95IRgjejQL+kC4mw+4yPxiAXA==} engines: {node: '>= 10'} cpu: [arm64] os: [linux] - '@next/swc-linux-x64-gnu@16.1.7': - resolution: {integrity: sha512-KWVf2gxYvHtvuT+c4MBOGxuse5TD7DsMFYSxVxRBnOzok/xryNeQSjXgxSv9QpIVlaGzEn/pIuI6Koosx8CGWA==} + '@next/swc-linux-x64-gnu@16.2.6': + resolution: {integrity: sha512-DOj182mPV8G3UkrayLoREM5YEYI+Dk5wv7Ox9xl1fFibAELEsFD0lDPfHIeILlutMMfdyhlzYPELG3peuKaurw==} engines: {node: '>= 10'} cpu: [x64] os: [linux] - '@next/swc-linux-x64-musl@16.1.7': - resolution: {integrity: sha512-HguhaGwsGr1YAGs68uRKc4aGWxLET+NevJskOcCAwXbwj0fYX0RgZW2gsOCzr9S11CSQPIkxmoSbuVaBp4Z3dA==} + '@next/swc-linux-x64-musl@16.2.6': + resolution: {integrity: sha512-HKQ5SP/V/ub73UvF7n/zeJlxk2kLmtL7Wzrg4WfmkjmNos5onJ2tKu7yZOPdL18A6Svfn3max29ym+ry7NkK4g==} engines: {node: '>= 10'} cpu: [x64] os: [linux] - '@next/swc-win32-arm64-msvc@16.1.7': - resolution: {integrity: sha512-S0n3KrDJokKTeFyM/vGGGR8+pCmXYrjNTk2ZozOL1C/JFdfUIL9O1ATaJOl5r2POe56iRChbsszrjMAdWSv7kQ==} + '@next/swc-win32-arm64-msvc@16.2.6': + resolution: {integrity: sha512-LZXpTlPyS5v7HhSmnvsLGP3iIYgYOBnc8r8ArlT55sGHV89bR2HlDdBjWQ+PY6SJMmk8TuVGFuxalnP3k/0Dwg==} engines: {node: '>= 10'} cpu: [arm64] os: [win32] - '@next/swc-win32-x64-msvc@16.1.7': - resolution: {integrity: sha512-mwgtg8CNZGYm06LeEd+bNnOUfwOyNem/rOiP14Lsz+AnUY92Zq/LXwtebtUiaeVkhbroRCQ0c8GlR4UT1U+0yg==} + '@next/swc-win32-x64-msvc@16.2.6': + resolution: {integrity: sha512-F0+4i0h9J6C4eE3EAPWsoCk7UW/dbzOjyzxY0qnDUOYFu6FFmdZ6l97/XdV3/Nz3VYyO7UWjyEJUXkGqcoXfMA==} engines: {node: '>= 10'} cpu: [x64] os: [win32] @@ -1912,6 +1912,9 @@ packages: '@swc/helpers@0.5.15': resolution: {integrity: sha512-JQ5TuMi45Owi4/BIMAJBoSQoOJu12oOk/gADqlcUL9JEdHB8vyjUSsxqeNXnmXHjYKMi2WcYtezGEEhqUI/E2g==} + '@swc/helpers@0.5.21': + resolution: {integrity: sha512-jI/VAmtdjB/RnI8GTnokyX7Ug8c+g+ffD6QRLa6XQewtnGyukKkKSk3wLTM3b5cjt1jNh9x0jfVlagdN2gDKQg==} + '@t3-oss/env-core@0.12.0': resolution: {integrity: sha512-lOPj8d9nJJTt81mMuN9GMk8x5veOt7q9m11OSnCBJhwp1QrL/qR+M8Y467ULBSm9SunosryWNbmQQbgoiMgcdw==} peerDependencies: @@ -2652,8 +2655,8 @@ packages: base64-js@1.5.1: resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==} - baseline-browser-mapping@2.10.8: - resolution: {integrity: sha512-PCLz/LXGBsNTErbtB6i5u4eLpHeMfi93aUv5duMmj6caNu6IphS4q6UevDnL36sZQv9lrP11dbPKGMaXPwMKfQ==} + baseline-browser-mapping@2.10.29: + resolution: {integrity: sha512-Asa2krT+XTPZINCS+2QcyS8WTkObE77RwkydwF7h6DmnKqbvlalz93m/dnphUyCa6SWSP51VgtEUf2FN+gelFQ==} engines: {node: '>=6.0.0'} hasBin: true @@ -2710,8 +2713,8 @@ packages: camelize@1.0.1: resolution: {integrity: sha512-dU+Tx2fsypxTgtLoE36npi3UqcjSSMNYfkqgmoEhtZrraP5VWq0K7FkWVTYa8eMPtnU/G2txVsfdCJTn9uzpuQ==} - caniuse-lite@1.0.30001780: - resolution: {integrity: sha512-llngX0E7nQci5BPJDqoZSbuZ5Bcs9F5db7EtgfwBerX9XGtkkiO4NwfDDIRzHTTwcYC8vC7bmeUEPGrKlR/TkQ==} + caniuse-lite@1.0.30001792: + resolution: {integrity: sha512-hVLMUZFgR4JJ6ACt1uEESvQN1/dBVqPAKY0hgrV70eN3391K6juAfTjKZLKvOMsx8PxA7gsY1/tLMMTcfFLLpw==} canvas-confetti@1.9.4: resolution: {integrity: sha512-yxQbJkAVrFXWNbTUjPqjF7G+g6pDotOUHGbkZq2NELZUMDpiJ85rIEazVb8GTaAptNW2miJAXbs1BtioA251Pw==} @@ -4389,8 +4392,8 @@ packages: react: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc react-dom: ^16.8 || ^17 || ^18 || ^19 || ^19.0.0-rc - next@16.1.7: - resolution: {integrity: sha512-WM0L7WrSvKwoLegLYr6V+mz+RIofqQgVAfHhMp9a88ms0cFX8iX9ew+snpWlSBwpkURJOUdvCEt3uLl3NNzvWg==} + next@16.2.6: + resolution: {integrity: sha512-qOVgKJg1+At15NpeUP+eJgCHvTCgXsogweq87Ri/Ix7PkqQHg4sdaXmSFqKlgaIXE4kW0g25LE68W87UANlHtw==} engines: {node: '>=20.9.0'} hasBin: true peerDependencies: @@ -5013,6 +5016,11 @@ packages: engines: {node: '>=10'} hasBin: true + semver@7.8.0: + resolution: {integrity: sha512-AcM7dV/5ul4EekoQ29Agm5vri8JNqRyj39o0qpX6vDF2GZrtutZl5RwgD1XnZjiTAfncsJhMI48QQH3sN87YNA==} + engines: {node: '>=10'} + hasBin: true + server-only@0.0.1: resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==} @@ -6066,7 +6074,7 @@ snapshots: tslib: 2.8.1 optional: true - '@emnapi/runtime@1.9.0': + '@emnapi/runtime@1.10.0': dependencies: tslib: 2.8.1 optional: true @@ -6343,7 +6351,7 @@ snapshots: '@img/sharp-wasm32@0.34.5': dependencies: - '@emnapi/runtime': 1.9.0 + '@emnapi/runtime': 1.10.0 optional: true '@img/sharp-win32-arm64@0.34.5': @@ -6598,38 +6606,38 @@ snapshots: '@napi-rs/wasm-runtime@0.2.12': dependencies: '@emnapi/core': 1.8.1 - '@emnapi/runtime': 1.9.0 + '@emnapi/runtime': 1.10.0 '@tybys/wasm-util': 0.10.1 optional: true - '@next/env@16.1.7': {} + '@next/env@16.2.6': {} '@next/eslint-plugin-next@15.5.12': dependencies: fast-glob: 3.3.1 - '@next/swc-darwin-arm64@16.1.7': + '@next/swc-darwin-arm64@16.2.6': optional: true - '@next/swc-darwin-x64@16.1.7': + '@next/swc-darwin-x64@16.2.6': optional: true - '@next/swc-linux-arm64-gnu@16.1.7': + '@next/swc-linux-arm64-gnu@16.2.6': optional: true - '@next/swc-linux-arm64-musl@16.1.7': + '@next/swc-linux-arm64-musl@16.2.6': optional: true - '@next/swc-linux-x64-gnu@16.1.7': + '@next/swc-linux-x64-gnu@16.2.6': optional: true - '@next/swc-linux-x64-musl@16.1.7': + '@next/swc-linux-x64-musl@16.2.6': optional: true - '@next/swc-win32-arm64-msvc@16.1.7': + '@next/swc-win32-arm64-msvc@16.2.6': optional: true - '@next/swc-win32-x64-msvc@16.1.7': + '@next/swc-win32-x64-msvc@16.2.6': optional: true '@nodelib/fs.scandir@2.1.5': @@ -7192,7 +7200,7 @@ snapshots: '@react-aria/interactions': 3.27.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@react-aria/utils': 3.33.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@react-types/shared': 3.33.1(react@19.2.4) - '@swc/helpers': 0.5.15 + '@swc/helpers': 0.5.21 clsx: 2.1.1 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) @@ -7203,13 +7211,13 @@ snapshots: '@react-aria/utils': 3.33.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@react-stately/flags': 3.1.2 '@react-types/shared': 3.33.1(react@19.2.4) - '@swc/helpers': 0.5.15 + '@swc/helpers': 0.5.21 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) '@react-aria/ssr@3.9.10(react@19.2.4)': dependencies: - '@swc/helpers': 0.5.15 + '@swc/helpers': 0.5.21 react: 19.2.4 '@react-aria/utils@3.33.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)': @@ -7218,18 +7226,18 @@ snapshots: '@react-stately/flags': 3.1.2 '@react-stately/utils': 3.11.0(react@19.2.4) '@react-types/shared': 3.33.1(react@19.2.4) - '@swc/helpers': 0.5.15 + '@swc/helpers': 0.5.21 clsx: 2.1.1 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) '@react-stately/flags@3.1.2': dependencies: - '@swc/helpers': 0.5.15 + '@swc/helpers': 0.5.21 '@react-stately/utils@3.11.0(react@19.2.4)': dependencies: - '@swc/helpers': 0.5.15 + '@swc/helpers': 0.5.21 react: 19.2.4 '@react-types/shared@3.33.1(react@19.2.4)': @@ -7437,6 +7445,10 @@ snapshots: dependencies: tslib: 2.8.1 + '@swc/helpers@0.5.21': + dependencies: + tslib: 2.8.1 + '@t3-oss/env-core@0.12.0(typescript@5.9.3)(zod@3.25.76)': optionalDependencies: typescript: 5.9.3 @@ -8249,7 +8261,7 @@ snapshots: base64-js@1.5.1: {} - baseline-browser-mapping@2.10.8: {} + baseline-browser-mapping@2.10.29: {} best-effort-json-parser@1.2.1: {} @@ -8313,7 +8325,7 @@ snapshots: camelize@1.0.1: {} - caniuse-lite@1.0.30001780: {} + caniuse-lite@1.0.30001792: {} canvas-confetti@1.9.4: {} @@ -9643,7 +9655,7 @@ snapshots: is-bun-module@2.0.0: dependencies: - semver: 7.7.4 + semver: 7.8.0 is-callable@1.2.7: {} @@ -10531,25 +10543,25 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4): + next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4): dependencies: - '@next/env': 16.1.7 + '@next/env': 16.2.6 '@swc/helpers': 0.5.15 - baseline-browser-mapping: 2.10.8 - caniuse-lite: 1.0.30001780 + baseline-browser-mapping: 2.10.29 + caniuse-lite: 1.0.30001792 postcss: 8.4.31 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) styled-jsx: 5.1.6(react@19.2.4) optionalDependencies: - '@next/swc-darwin-arm64': 16.1.7 - '@next/swc-darwin-x64': 16.1.7 - '@next/swc-linux-arm64-gnu': 16.1.7 - '@next/swc-linux-arm64-musl': 16.1.7 - '@next/swc-linux-x64-gnu': 16.1.7 - '@next/swc-linux-x64-musl': 16.1.7 - '@next/swc-win32-arm64-msvc': 16.1.7 - '@next/swc-win32-x64-msvc': 16.1.7 + '@next/swc-darwin-arm64': 16.2.6 + '@next/swc-darwin-x64': 16.2.6 + '@next/swc-linux-arm64-gnu': 16.2.6 + '@next/swc-linux-arm64-musl': 16.2.6 + '@next/swc-linux-x64-gnu': 16.2.6 + '@next/swc-linux-x64-musl': 16.2.6 + '@next/swc-win32-arm64-msvc': 16.2.6 + '@next/swc-win32-x64-msvc': 16.2.6 '@opentelemetry/api': 1.9.0 '@playwright/test': 1.59.1 sharp: 0.34.5 @@ -10557,13 +10569,13 @@ snapshots: - '@babel/core' - babel-plugin-macros - nextra-theme-docs@4.6.1(@types/react@19.2.13)(next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(nextra@4.6.1(next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)): + nextra-theme-docs@4.6.1(@types/react@19.2.13)(next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(nextra@4.6.1(next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(use-sync-external-store@1.6.0(react@19.2.4)): dependencies: '@headlessui/react': 2.2.9(react-dom@19.2.4(react@19.2.4))(react@19.2.4) clsx: 2.1.1 - next: 16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + next: 16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) next-themes: 0.4.6(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - nextra: 4.6.1(next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3) + nextra: 4.6.1(next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3) react: 19.2.4 react-compiler-runtime: 19.1.0-rc.3(react@19.2.4) react-dom: 19.2.4(react@19.2.4) @@ -10575,7 +10587,7 @@ snapshots: - immer - use-sync-external-store - nextra@4.6.1(next@16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3): + nextra@4.6.1(next@16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(typescript@5.9.3): dependencies: '@formatjs/intl-localematcher': 0.6.2 '@headlessui/react': 2.2.9(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -10596,7 +10608,7 @@ snapshots: mdast-util-gfm: 3.1.0 mdast-util-to-hast: 13.2.1 negotiator: 1.0.0 - next: 16.1.7(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) + next: 16.2.6(@opentelemetry/api@1.9.0)(@playwright/test@1.59.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) react: 19.2.4 react-compiler-runtime: 19.1.0-rc.3(react@19.2.4) react-dom: 19.2.4(react@19.2.4) @@ -10925,7 +10937,7 @@ snapshots: postcss@8.4.31: dependencies: - nanoid: 3.3.11 + nanoid: 3.3.12 picocolors: 1.1.1 source-map-js: 1.2.1 @@ -11365,6 +11377,8 @@ snapshots: semver@7.7.4: {} + semver@7.8.0: {} + server-only@0.0.1: {} set-function-length@1.2.2: @@ -11393,7 +11407,7 @@ snapshots: dependencies: '@img/colour': 1.1.0 detect-libc: 2.1.2 - semver: 7.7.4 + semver: 7.8.0 optionalDependencies: '@img/sharp-darwin-arm64': 0.34.5 '@img/sharp-darwin-x64': 0.34.5 From 20d2d2b3731edf9d5d72a191471c1fd856453350 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Tue, 12 May 2026 04:55:13 +0200 Subject: [PATCH 12/86] fix(middleware): Handle invalid tool calls in dangling pairing middleware (#2890) (#2891) --- .../dangling_tool_call_middleware.py | 83 +++++++++++++------ .../test_dangling_tool_call_middleware.py | 50 +++++++++++ 2 files changed, 107 insertions(+), 26 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py index 7bf600b9f..5bb54f3e5 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py @@ -36,42 +36,73 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): @staticmethod def _message_tool_calls(msg) -> list[dict]: - """Return normalized tool calls from structured fields or raw provider payloads.""" + """Return normalized tool calls from structured fields or raw provider payloads. + + LangChain stores malformed provider function calls in ``invalid_tool_calls``. + They do not execute, but provider adapters may still serialize enough of + the call id/name back into the next request that strict OpenAI-compatible + validators expect a matching ToolMessage. Treat them as dangling calls so + the next model request stays well-formed and the model sees a recoverable + tool error instead of another provider 400. + """ + normalized: list[dict] = [] + tool_calls = getattr(msg, "tool_calls", None) or [] - if tool_calls: - return list(tool_calls) + normalized.extend(list(tool_calls)) raw_tool_calls = (getattr(msg, "additional_kwargs", None) or {}).get("tool_calls") or [] - normalized: list[dict] = [] - for raw_tc in raw_tool_calls: - if not isinstance(raw_tc, dict): + if not tool_calls: + for raw_tc in raw_tool_calls: + if not isinstance(raw_tc, dict): + continue + + function = raw_tc.get("function") + name = raw_tc.get("name") + if not name and isinstance(function, dict): + name = function.get("name") + + args = raw_tc.get("args", {}) + if not args and isinstance(function, dict): + raw_args = function.get("arguments") + if isinstance(raw_args, str): + try: + parsed_args = json.loads(raw_args) + except (TypeError, ValueError, json.JSONDecodeError): + parsed_args = {} + args = parsed_args if isinstance(parsed_args, dict) else {} + + normalized.append( + { + "id": raw_tc.get("id"), + "name": name or "unknown", + "args": args if isinstance(args, dict) else {}, + } + ) + + for invalid_tc in getattr(msg, "invalid_tool_calls", None) or []: + if not isinstance(invalid_tc, dict): continue - - function = raw_tc.get("function") - name = raw_tc.get("name") - if not name and isinstance(function, dict): - name = function.get("name") - - args = raw_tc.get("args", {}) - if not args and isinstance(function, dict): - raw_args = function.get("arguments") - if isinstance(raw_args, str): - try: - parsed_args = json.loads(raw_args) - except (TypeError, ValueError, json.JSONDecodeError): - parsed_args = {} - args = parsed_args if isinstance(parsed_args, dict) else {} - normalized.append( { - "id": raw_tc.get("id"), - "name": name or "unknown", - "args": args if isinstance(args, dict) else {}, + "id": invalid_tc.get("id"), + "name": invalid_tc.get("name") or "unknown", + "args": {}, + "invalid": True, + "error": invalid_tc.get("error"), } ) return normalized + @staticmethod + def _synthetic_tool_message_content(tool_call: dict) -> str: + if tool_call.get("invalid"): + error = tool_call.get("error") + if isinstance(error, str) and error: + return f"[Tool call could not be executed because its arguments were invalid: {error}]" + return "[Tool call could not be executed because its arguments were invalid.]" + return "[Tool call was interrupted and did not return a result.]" + def _build_patched_messages(self, messages: list) -> list | None: """Return a new message list with patches inserted at the correct positions. @@ -114,7 +145,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: patched.append( ToolMessage( - content="[Tool call was interrupted and did not return a result.]", + content=self._synthetic_tool_message_content(tc), tool_call_id=tc_id, name=tc.get("name", "unknown"), status="error", diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index 90c162eac..b1d5c476a 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -14,6 +14,10 @@ def _ai_with_tool_calls(tool_calls): return AIMessage(content="", tool_calls=tool_calls) +def _ai_with_invalid_tool_calls(invalid_tool_calls): + return AIMessage(content="", tool_calls=[], invalid_tool_calls=invalid_tool_calls) + + def _tool_msg(tool_call_id, name="test_tool"): return ToolMessage(content="result", tool_call_id=tool_call_id, name=name) @@ -22,6 +26,16 @@ def _tc(name="bash", tc_id="call_1"): return {"name": name, "id": tc_id, "args": {}} +def _invalid_tc(name="write_file", tc_id="write_file:36", error="Failed to parse tool arguments: malformed JSON"): + return { + "type": "invalid_tool_call", + "name": name, + "id": tc_id, + "args": '{"description":"write report","path":"/mnt/user-data/outputs/report.md","content":"bad {"json"}"}', + "error": error, + } + + class TestBuildPatchedMessagesNoPatch: def test_empty_messages(self): mw = DanglingToolCallMiddleware() @@ -144,6 +158,42 @@ class TestBuildPatchedMessagesPatching: assert patched[1].name == "bash" assert patched[1].status == "error" + def test_invalid_tool_call_is_patched(self): + mw = DanglingToolCallMiddleware() + msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])] + patched = mw._build_patched_messages(msgs) + assert patched is not None + assert len(patched) == 2 + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "write_file:36" + assert patched[1].name == "write_file" + assert patched[1].status == "error" + assert "arguments were invalid" in patched[1].content + assert "Failed to parse tool arguments" in patched[1].content + + def test_valid_and_invalid_tool_calls_are_both_patched(self): + mw = DanglingToolCallMiddleware() + msgs = [ + AIMessage( + content="", + tool_calls=[_tc("bash", "call_1")], + invalid_tool_calls=[_invalid_tc()], + ) + ] + patched = mw._build_patched_messages(msgs) + assert patched is not None + tool_msgs = [m for m in patched if isinstance(m, ToolMessage)] + assert len(tool_msgs) == 2 + assert {tm.tool_call_id for tm in tool_msgs} == {"call_1", "write_file:36"} + + def test_invalid_tool_call_already_responded_is_not_patched(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_invalid_tool_calls([_invalid_tc()]), + _tool_msg("write_file:36", "write_file"), + ] + assert mw._build_patched_messages(msgs) is None + class TestWrapModelCall: def test_no_patch_passthrough(self): From 84f88b6610e5c6384735e703809bc8b35e33dacb Mon Sep 17 00:00:00 2001 From: Eilen Shin <136898293+Eilen6316@users.noreply.github.com> Date: Tue, 12 May 2026 16:19:21 +0800 Subject: [PATCH 13/86] docs: align runtime docs with gateway mode (#2868) Co-authored-by: Willem Jiang --- CONTRIBUTING.md | 8 +-- README_fr.md | 6 +- README_ja.md | 6 +- README_zh.md | 6 +- backend/CONTRIBUTING.md | 5 +- backend/README.md | 57 ++++++++----------- backend/docs/API.md | 31 ++++++---- backend/docs/ARCHITECTURE.md | 49 ++++++++-------- frontend/README.md | 6 +- .../en/application/agents-and-threads.mdx | 7 +-- .../en/application/deployment-guide.mdx | 29 +++------- frontend/src/content/en/application/index.mdx | 20 ++----- .../operations-and-troubleshooting.mdx | 13 ++--- .../content/en/application/quick-start.mdx | 11 ++-- frontend/src/content/en/harness/skills.mdx | 2 +- .../content/zh/application/configuration.mdx | 1 - .../zh/application/deployment-guide.mdx | 27 +++------ frontend/src/content/zh/application/index.mdx | 20 ++----- .../operations-and-troubleshooting.mdx | 17 ++---- .../content/zh/application/quick-start.mdx | 11 ++-- skills/public/claude-to-deerflow/SKILL.md | 4 +- 21 files changed, 135 insertions(+), 201 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 51b834b4f..ceebba99c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -185,9 +185,9 @@ If you need to start services individually: 1. **Start backend service**: ```bash - # Terminal 1: Start Gateway API and embedded LangGraph-compatible runtime (port 8001) + # Terminal 1: Start Gateway API + embedded agent runtime (port 8001) cd backend - make gateway + make dev # Terminal 2: Start Frontend (port 3000) cd frontend @@ -207,7 +207,7 @@ If you need to start services individually: The nginx configuration provides: - Unified entry point on port 2026 -- Gateway owns `/api/langgraph/*` and translates those public LangGraph-compatible paths to its native `/api/*` routers behind nginx +- Rewrites `/api/langgraph/*` to Gateway's LangGraph-compatible API (8001) - Routes other `/api/*` endpoints to Gateway API (8001) - Routes non-API requests to Frontend (3000) - Same-origin API routing; split-origin or port-forwarded browser clients should use the Gateway `GATEWAY_CORS_ORIGINS` allowlist @@ -231,7 +231,7 @@ deer-flow/ ├── backend/ # Backend application │ ├── src/ │ │ ├── gateway/ # Gateway API and LangGraph-compatible runtime (port 8001) -│ │ ├── agents/ # LangGraph agent definitions +│ │ ├── agents/ # LangGraph agent runtime used by Gateway │ │ ├── mcp/ # Model Context Protocol integration │ │ ├── skills/ # Skills system │ │ └── sandbox/ # Sandbox execution diff --git a/README_fr.md b/README_fr.md index 3b8dc3d41..f144d8bc5 100644 --- a/README_fr.md +++ b/README_fr.md @@ -228,7 +228,7 @@ make down # Stop and remove containers ``` > [!NOTE] -> Le serveur d'agents LangGraph fonctionne actuellement via `langgraph dev` (le serveur CLI open source). +> Le runtime d'agent s'exécute actuellement dans la Gateway. nginx réécrit `/api/langgraph/*` vers l'API compatible LangGraph servie par la Gateway. Accès : http://localhost:2026 @@ -296,8 +296,8 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca ```yaml channels: - # LangGraph Server URL (default: http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL (default: http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/README_ja.md b/README_ja.md index d2ba81750..2bf060799 100644 --- a/README_ja.md +++ b/README_ja.md @@ -181,7 +181,7 @@ make down # コンテナを停止して削除 ``` > [!NOTE] -> LangGraphエージェントサーバーは現在`langgraph dev`(オープンソースCLIサーバー)経由で実行されます。 +> Agentランタイムは現在Gateway内で実行されます。`/api/langgraph/*`はnginxによってGatewayのLangGraph-compatible APIへ書き換えられます。 アクセス: http://localhost:2026 @@ -249,8 +249,8 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート ```yaml channels: - # LangGraphサーバーURL(デフォルト: http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL(デフォルト: http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL(デフォルト: http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/README_zh.md b/README_zh.md index d5317082e..ec67b95d6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -184,7 +184,7 @@ make down # 停止并移除容器 ``` > [!NOTE] -> 当前 LangGraph agent server 通过开源 CLI 服务 `langgraph dev` 运行。 +> 当前 Agent 运行时嵌入在 Gateway 中运行,`/api/langgraph/*` 会由 nginx 重写到 Gateway 的 LangGraph-compatible API。 访问地址:http://localhost:2026 @@ -254,8 +254,8 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应 ```yaml channels: - # LangGraph Server URL(默认:http://localhost:2024) - langgraph_url: http://localhost:2024 + # LangGraph-compatible Gateway API base URL(默认:http://localhost:8001/api) + langgraph_url: http://localhost:8001/api # Gateway API URL(默认:http://localhost:8001) gateway_url: http://localhost:8001 diff --git a/backend/CONTRIBUTING.md b/backend/CONTRIBUTING.md index 322710e74..f7ef58447 100644 --- a/backend/CONTRIBUTING.md +++ b/backend/CONTRIBUTING.md @@ -56,11 +56,8 @@ export OPENAI_API_KEY="your-api-key" ### Run the Development Server ```bash -# Terminal 1: LangGraph server +# Gateway API + embedded agent runtime make dev - -# Terminal 2: Gateway API -make gateway ``` ## Project Structure diff --git a/backend/README.md b/backend/README.md index 9b4d26fb1..18d89c2be 100644 --- a/backend/README.md +++ b/backend/README.md @@ -11,34 +11,26 @@ DeerFlow is a LangGraph-based AI super agent with sandbox execution, persistent │ Nginx (Port 2026) │ │ Unified reverse proxy │ └───────┬──────────────────┬───────────┘ - │ │ - /api/langgraph/* │ │ /api/* (other) - ▼ ▼ - ┌──────────────────────────────────────────────┐ - │ Gateway API (8001) │ - │ FastAPI REST + LangGraph-compatible runtime │ - │ │ - │ Models, MCP, Skills, Memory, Uploads, │ - │ Artifacts, Threads, Runs, Streaming │ - │ │ - │ ┌────────────────┐ │ - │ │ Lead Agent │ │ - │ │ ┌──────────┐ │ │ - │ │ │Middleware│ │ │ - │ │ │ Chain │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │ Tools │ │ │ - │ │ └──────────┘ │ │ - │ │ ┌──────────┐ │ │ - │ │ │Subagents │ │ │ - │ │ └──────────┘ │ │ - │ └────────────────┘ │ - └──────────────────────────────────────────────┘ + │ + /api/langgraph/* │ /api/* (other) + rewritten to /api/* │ + ▼ + ┌────────────────────────────────────────┐ + │ Gateway API (8001) │ + │ FastAPI REST + agent runtime │ + │ │ + │ Models, MCP, Skills, Memory, Uploads, │ + │ Artifacts, Threads, Runs, Streaming │ + │ │ + │ ┌────────────────────────────────────┐ │ + │ │ Lead Agent │ │ + │ │ Middleware Chain, Tools, Subagents │ │ + │ └────────────────────────────────────┘ │ + └────────────────────────────────────────┘ ``` **Request Routing** (via Nginx): -- `/api/langgraph/*` → Gateway API - LangGraph-compatible agent interactions, threads, runs, and streaming translated to native `/api/*` routers +- `/api/langgraph/*` → Gateway LangGraph-compatible API - agent interactions, threads, streaming - `/api/*` (other) → Gateway API - models, MCP, skills, memory, artifacts, uploads, thread-local cleanup - `/` (non-API) → Frontend - Next.js web interface @@ -196,7 +188,7 @@ export OPENAI_API_KEY="your-api-key-here" **Full Application** (from project root): ```bash -make dev # Starts LangGraph + Gateway + Frontend + Nginx +make dev # Starts Gateway + Frontend + Nginx ``` Access at: http://localhost:2026 @@ -204,14 +196,11 @@ Access at: http://localhost:2026 **Backend Only** (from backend directory): ```bash -# Terminal 1: LangGraph server +# Gateway API + embedded agent runtime make dev - -# Terminal 2: Gateway API -make gateway ``` -Direct access: LangGraph at http://localhost:2024, Gateway at http://localhost:8001 +Direct access: Gateway at http://localhost:8001 --- @@ -247,7 +236,7 @@ backend/ │ └── utils/ # Utilities ├── docs/ # Documentation ├── tests/ # Test suite -├── langgraph.json # LangGraph server configuration +├── langgraph.json # LangGraph graph registry for tooling/Studio compatibility ├── pyproject.toml # Python dependencies ├── Makefile # Development commands └── Dockerfile # Container build @@ -365,8 +354,8 @@ If a provider is explicitly enabled but required credentials are missing, or the ```bash make install # Install dependencies -make dev # Run LangGraph server (port 2024) -make gateway # Run Gateway API (port 8001) +make dev # Run Gateway API + embedded agent runtime (port 8001) +make gateway # Run Gateway API without reload (port 8001) make lint # Run linter (ruff) make format # Format code (ruff) ``` diff --git a/backend/docs/API.md b/backend/docs/API.md index 293c1ebd1..d0b06ef0b 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -561,12 +561,13 @@ location /api/ { --- -## WebSocket Support +## Streaming Support -The LangGraph server supports WebSocket connections for real-time streaming. Connect to: +Gateway's LangGraph-compatible API streams run events with Server-Sent Events (SSE): -``` -ws://localhost:2026/api/langgraph/threads/{thread_id}/runs/stream +```http +POST /api/langgraph/threads/{thread_id}/runs/stream +Accept: text/event-stream ``` --- @@ -602,13 +603,21 @@ const response = await fetch('/api/models'); const data = await response.json(); console.log(data.models); -// Using EventSource for streaming -const eventSource = new EventSource( - `/api/langgraph/threads/${threadId}/runs/stream` -); -eventSource.onmessage = (event) => { - console.log(JSON.parse(event.data)); -}; +// Create a run and stream SSE events +const streamResponse = await fetch(`/api/langgraph/threads/${threadId}/runs/stream`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + }, + body: JSON.stringify({ + input: { messages: [{ role: "user", content: "Hello" }] }, + stream_mode: ["values", "messages-tuple", "custom"], + }), +}); + +const reader = streamResponse.body?.getReader(); +// Decode and parse SSE frames from reader in your client code. ``` ### cURL Examples diff --git a/backend/docs/ARCHITECTURE.md b/backend/docs/ARCHITECTURE.md index e6fdbe217..f1557a6fb 100644 --- a/backend/docs/ARCHITECTURE.md +++ b/backend/docs/ARCHITECTURE.md @@ -20,24 +20,22 @@ This document provides a comprehensive overview of the DeerFlow backend architec │ └────────────────────────────────────────────────────────────────────┘ │ └─────────────────────────────────┬────────────────────────────────────────┘ │ - ┌───────────────────────┼───────────────────────┐ - │ │ │ - ▼ ▼ ▼ -┌─────────────────────┐ ┌─────────────────────┐ ┌─────────────────────┐ -│ Embedded Runtime │ │ Gateway API │ │ Frontend │ -│ (inside Gateway) │ │ (Port 8001) │ │ (Port 3000) │ -│ │ │ │ │ │ -│ - Agent Runtime │ │ - Models API │ │ - Next.js App │ -│ - Thread Mgmt │ │ - MCP Config │ │ - React UI │ -│ - SSE Streaming │ │ - Skills Mgmt │ │ - Chat Interface │ -│ - Checkpointing │ │ - File Uploads │ │ │ -│ │ │ - Thread Cleanup │ │ │ -│ │ │ - Artifacts │ │ │ -└─────────────────────┘ └─────────────────────┘ └─────────────────────┘ - │ │ - │ ┌─────────────────┘ - │ │ - ▼ ▼ + ┌───────────────────────┴───────────────────────┐ + │ │ + ▼ ▼ +┌─────────────────────────────────────────────┐ ┌─────────────────────┐ +│ Gateway API │ │ Frontend │ +│ (Port 8001) │ │ (Port 3000) │ +│ │ │ │ +│ - LangGraph-compatible runs/threads API │ │ - Next.js App │ +│ - Embedded Agent Runtime │ │ - React UI │ +│ - SSE Streaming │ │ - Chat Interface │ +│ - Checkpointing │ │ │ +│ - Models, MCP, Skills, Uploads, Artifacts │ │ │ +│ - Thread Cleanup │ │ │ +└─────────────────────────────────────────────┘ └─────────────────────┘ + │ + ▼ ┌──────────────────────────────────────────────────────────────────────────┐ │ Shared Configuration │ │ ┌─────────────────────────┐ ┌────────────────────────────────────────┐ │ @@ -52,9 +50,9 @@ This document provides a comprehensive overview of the DeerFlow backend architec ## Component Details -### Embedded LangGraph Runtime +### Gateway Embedded Agent Runtime -The LangGraph-compatible runtime runs inside the Gateway process and is built on LangGraph for robust multi-agent workflow orchestration. +The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for robust multi-agent workflow orchestration. Nginx rewrites `/api/langgraph/*` to Gateway's native `/api/*` routes, so the public API remains compatible with LangGraph SDK clients without running a separate LangGraph server. **Entry Point**: `packages/harness/deerflow/agents/lead_agent/agent.py:make_lead_agent` @@ -65,7 +63,7 @@ The LangGraph-compatible runtime runs inside the Gateway process and is built on - Tool execution orchestration - SSE streaming for real-time responses -**Configuration**: `langgraph.json` +**Graph registry**: `langgraph.json` remains available for tooling and Studio compatibility. ```json { @@ -84,6 +82,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl **Routers**: - `models.py` - `/api/models` - Model listing and details +- `thread_runs.py` / `runs.py` - `/api/threads/{id}/runs`, `/api/runs/*` - LangGraph-compatible runs and streaming - `mcp.py` - `/api/mcp` - MCP server configuration - `skills.py` - `/api/skills` - Skills management - `uploads.py` - `/api/threads/{id}/uploads` - File upload @@ -91,7 +90,7 @@ FastAPI application providing REST endpoints plus the public LangGraph-compatibl - `artifacts.py` - `/api/threads/{id}/artifacts` - Artifact serving - `suggestions.py` - `/api/threads/{id}/suggestions` - Follow-up suggestion generation -The web conversation delete flow is now split across both backend surfaces: LangGraph handles `DELETE /api/langgraph/threads/{thread_id}` for thread state, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`. +The web conversation delete flow first deletes Gateway-managed thread state through the LangGraph-compatible route, then the Gateway `threads.py` router removes DeerFlow-managed filesystem data via `Paths.delete_thread_dir()`. ### Agent Architecture @@ -354,9 +353,9 @@ SKILL.md Format: {"input": {"messages": [{"role": "user", "content": "Hello"}]}} 2. Nginx → Gateway API (8001) - Routes `/api/langgraph/*` to the Gateway's LangGraph-compatible runtime + `/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes -3. Embedded LangGraph runtime +3. Gateway embedded runtime a. Load/create thread state b. Execute middleware chain: - ThreadDataMiddleware: Set up paths @@ -412,7 +411,7 @@ SKILL.md Format: ### Thread Cleanup Flow ``` -1. Client deletes conversation via LangGraph +1. Client deletes conversation via the LangGraph-compatible Gateway route DELETE /api/langgraph/threads/{thread_id} 2. Web UI follows up with Gateway cleanup diff --git a/frontend/README.md b/frontend/README.md index 6db881301..4ad70fb1f 100644 --- a/frontend/README.md +++ b/frontend/README.md @@ -82,10 +82,10 @@ pnpm start Key environment variables (see `.env.example` for full list): ```bash -# Backend API URLs (optional, uses nginx proxy by default) +# Backend API URL (optional, uses local Next.js/nginx proxy by default) NEXT_PUBLIC_BACKEND_BASE_URL="http://localhost:8001" -# LangGraph API URLs (optional, uses nginx proxy by default) -NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:2024" +# LangGraph-compatible API URL (optional, uses local Next.js/nginx proxy by default) +NEXT_PUBLIC_LANGGRAPH_BASE_URL="http://localhost:8001/api" ``` ## Project Structure diff --git a/frontend/src/content/en/application/agents-and-threads.mdx b/frontend/src/content/en/application/agents-and-threads.mdx index bbf3cfc7e..0a281a33e 100644 --- a/frontend/src/content/en/application/agents-and-threads.mdx +++ b/frontend/src/content/en/application/agents-and-threads.mdx @@ -111,10 +111,9 @@ checkpointer: ``` - The LangGraph Server manages its own state separately. The - checkpointer setting in config.yaml applies to the - embedded DeerFlowClient (used in direct Python integrations), not - to the LangGraph Server deployment used by DeerFlow App. + The Gateway embedded runtime uses the checkpointer setting in + config.yaml. The same setting is also used by + DeerFlowClient in direct Python integrations. ### Thread data storage diff --git a/frontend/src/content/en/application/deployment-guide.mdx b/frontend/src/content/en/application/deployment-guide.mdx index 04b3599c0..52b59cf01 100644 --- a/frontend/src/content/en/application/deployment-guide.mdx +++ b/frontend/src/content/en/application/deployment-guide.mdx @@ -23,8 +23,7 @@ Services started: | Service | Port | Description | | ----------- | ---- | ------------------------ | -| LangGraph | 2024 | DeerFlow Harness runtime | -| Gateway API | 8001 | FastAPI backend | +| Gateway API | 8001 | FastAPI backend + embedded agent runtime | | Frontend | 3000 | Next.js UI | | nginx | 2026 | Unified reverse proxy | @@ -36,13 +35,12 @@ Access the app at **http://localhost:2026**. make stop ``` -Stops all four services. Safe to run even if a service is not running. +Stops all services. Safe to run even if a service is not running. ``` -logs/langgraph.log # Agent runtime logs -logs/gateway.log # API gateway logs +logs/gateway.log # API gateway and agent runtime logs logs/frontend.log # Next.js dev server logs logs/nginx.log # nginx access/error logs ``` @@ -50,7 +48,7 @@ logs/nginx.log # nginx access/error logs Tail a log in real time: ```bash -tail -f logs/langgraph.log +tail -f logs/gateway.log ``` @@ -74,7 +72,7 @@ export DEER_FLOW_ROOT=/path/to/deer-flow docker compose -f docker/docker-compose-dev.yaml up --build ``` -Services: nginx, frontend, gateway, langgraph, and optionally provisioner (for K8s-managed sandboxes). +Services: nginx, frontend, gateway, and optionally provisioner (for K8s-managed sandboxes). Access the app at **http://localhost:2026**. @@ -99,7 +97,7 @@ The `docker-compose*.yaml` files include an `env_file: ../.env` directive that l ### Data persistence -Thread data is stored in `backend/.deer-flow/threads/`. In Docker deployments, this directory is bind-mounted into the langgraph container. +Thread data is stored in `backend/.deer-flow/threads/`. In Docker deployments, this directory is bind-mounted into the gateway container. To avoid data loss when containers are recreated: @@ -161,14 +159,7 @@ When `USERDATA_PVC_NAME` is set, the provisioner automatically uses subPath (`th ### nginx configuration -nginx routes all traffic. Key environment variables that control routing: - -| Variable | Default | Description | -| -------------------- | ---------------- | --------------------------------------- | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | LangGraph service address | -| `LANGGRAPH_REWRITE` | `/` | URL rewrite prefix for LangGraph routes | - -These are set in the Docker Compose environment and processed by `envsubst` at container startup. +nginx routes all traffic to the frontend or Gateway. `/api/langgraph/*` is rewritten to Gateway's LangGraph-compatible `/api/*` routes, so no separate LangGraph upstream is required. ### Authentication @@ -186,8 +177,7 @@ openssl rand -base64 32 | Service | Minimum | Recommended | | ------------------------------- | ---------------- | ---------------- | -| LangGraph (agent runtime) | 2 vCPU, 4 GB RAM | 4 vCPU, 8 GB RAM | -| Gateway | 0.5 vCPU, 512 MB | 1 vCPU, 1 GB | +| Gateway + agent runtime | 2 vCPU, 4 GB RAM | 4 vCPU, 8 GB RAM | | Frontend | 0.5 vCPU, 512 MB | 1 vCPU, 1 GB | | Sandbox container (per session) | 1 vCPU, 1 GB | 2 vCPU, 2 GB | @@ -199,9 +189,6 @@ After starting, verify the deployment: # Check Gateway health curl http://localhost:8001/health -# Check LangGraph health -curl http://localhost:2024/ok - # List configured models (through nginx) curl http://localhost:2026/api/models ``` diff --git a/frontend/src/content/en/application/index.mdx b/frontend/src/content/en/application/index.mdx index 2cb15a911..b45a6cbf0 100644 --- a/frontend/src/content/en/application/index.mdx +++ b/frontend/src/content/en/application/index.mdx @@ -25,11 +25,11 @@ DeerFlow App is the reference implementation of what a production DeerFlow exper | **Streaming responses** | Real-time token streaming with thinking steps and tool call visibility | | **Artifact viewer** | In-browser preview and download of files and outputs produced by the agent | | **Extensions UI** | Enable/disable MCP servers and skills without editing config files | -| **Gateway API** | FastAPI-based REST API that bridges the frontend and the LangGraph runtime | +| **Gateway API** | FastAPI-based REST API with the embedded LangGraph-compatible agent runtime | ## Architecture -The DeerFlow App runs as four services behind a single nginx reverse proxy: +The DeerFlow App runs behind a single nginx reverse proxy: ``` ┌──────────────────┐ @@ -42,19 +42,11 @@ The DeerFlow App runs as four services behind a single nginx reverse proxy: │ Frontend :3000 │ │ Gateway API :8001 │ │ (Next.js) │ │ (FastAPI) │ └──────────────────┘ └──────────────────────┘ - │ - ┌─────────┘ - ▼ - ┌──────────────────────┐ - │ LangGraph :2024 │ - │ (DeerFlow Harness) │ - └──────────────────────┘ ``` -- **nginx**: routes requests — `/api/*` to the Gateway, LangGraph streaming endpoints to LangGraph directly, and everything else to the frontend. -- **Frontend** (Next.js + React): the browser UI. Communicates with both the Gateway and LangGraph. -- **Gateway** (FastAPI): handles API operations — model listing, agent CRUD, memory, extensions management, file uploads. -- **LangGraph**: the DeerFlow Harness runtime. Manages thread state, agent execution, and streaming. +- **nginx**: routes requests — `/api/*` and `/api/langgraph/*` to Gateway, and everything else to the frontend. +- **Frontend** (Next.js + React): the browser UI. Communicates with Gateway. +- **Gateway** (FastAPI): handles API operations and the embedded LangGraph-compatible runtime for thread state, agent execution, and streaming. ## Technology stack @@ -64,7 +56,7 @@ The DeerFlow App runs as four services behind a single nginx reverse proxy: | Gateway | FastAPI, Python 3.12, uvicorn | | Agent runtime | LangGraph, LangChain, DeerFlow Harness | | Reverse proxy | nginx | -| State persistence | LangGraph Server (default) + optional SQLite/PostgreSQL checkpointer | +| State persistence | Gateway runtime + optional SQLite/PostgreSQL checkpointer | diff --git a/frontend/src/content/en/application/operations-and-troubleshooting.mdx b/frontend/src/content/en/application/operations-and-troubleshooting.mdx index 8b21cf4b4..0f8d7e44c 100644 --- a/frontend/src/content/en/application/operations-and-troubleshooting.mdx +++ b/frontend/src/content/en/application/operations-and-troubleshooting.mdx @@ -15,15 +15,13 @@ All services write logs to the `logs/` directory when started with `make dev`: | File | Service | | -------------------- | ------------------------------------ | -| `logs/langgraph.log` | LangGraph / DeerFlow Harness runtime | -| `logs/gateway.log` | FastAPI Gateway API | +| `logs/gateway.log` | FastAPI Gateway API and agent runtime | | `logs/frontend.log` | Next.js frontend dev server | | `logs/nginx.log` | nginx reverse proxy | Tail logs in real time: ```bash -tail -f logs/langgraph.log tail -f logs/gateway.log ``` @@ -41,9 +39,6 @@ Verify each service is responding: # Gateway health curl http://localhost:8001/health -# LangGraph health -curl http://localhost:2024/ok - # Through nginx (verifies full proxy chain) curl http://localhost:2026/api/models ``` @@ -66,7 +61,7 @@ grep config_version config.yaml ### The app loads but the agent doesn't respond -1. Check `logs/langgraph.log` for startup errors. +1. Check `logs/gateway.log` for startup errors. 2. Verify your model is correctly configured in `config.yaml` with a valid API key. 3. Confirm the API key environment variable is set in the shell that ran `make dev`. 4. Test the model endpoint directly with `curl` to rule out network issues. @@ -126,7 +121,7 @@ Connection refused: http://provisioner:8002 If MCP tools appear in `extensions_config.json` but are not available in the agent: -1. Check `logs/langgraph.log` for MCP initialization errors. +1. Check `logs/gateway.log` for MCP initialization errors. 2. Verify the MCP server command is installed (`npx`, `uvx`, or the relevant binary). 3. Test the server command manually to confirm it starts without errors. 4. Set `log_level: debug` to see detailed MCP loading output. @@ -137,7 +132,7 @@ If MCP tools appear in `extensions_config.json` but are not available in the age - Verify `memory.enabled: true` in `config.yaml`. - Check that the storage path is writable: `ls -la backend/.deer-flow/`. -- Look for memory update errors in `logs/langgraph.log` (search for "memory"). +- Look for memory update errors in `logs/gateway.log` (search for "memory"). ## Data backup diff --git a/frontend/src/content/en/application/quick-start.mdx b/frontend/src/content/en/application/quick-start.mdx index 5ecfb3a26..c3baa0764 100644 --- a/frontend/src/content/en/application/quick-start.mdx +++ b/frontend/src/content/en/application/quick-start.mdx @@ -1,6 +1,6 @@ --- title: Quick Start -description: This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. All four services (LangGraph, Gateway, Frontend, nginx) start together and are accessible through a single URL. +description: This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. Gateway, Frontend, and nginx start together and are accessible through a single URL. --- import { Callout, Cards, Steps } from "nextra/components"; @@ -12,7 +12,7 @@ import { Callout, Cards, Steps } from "nextra/components"; Python 3.12+, Node.js 22+, and at least one LLM API key. -This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. All four services (LangGraph, Gateway, Frontend, nginx) start together and are accessible through a single URL. +This guide walks you through starting DeerFlow App on your local machine using the `make dev` workflow. Gateway, Frontend, and nginx start together and are accessible through a single URL. ## Prerequisites @@ -88,8 +88,7 @@ make dev This starts: -- LangGraph server on port `2024` -- Gateway API on port `8001` +- Gateway API and embedded agent runtime on port `8001` - Frontend on port `3000` - nginx reverse proxy on port `2026` @@ -113,15 +112,13 @@ Log files: | Service | Log file | | --------- | -------------------- | -| LangGraph | `logs/langgraph.log` | | Gateway | `logs/gateway.log` | | Frontend | `logs/frontend.log` | | nginx | `logs/nginx.log` | If something is not working, check the log files first. Most startup errors - (missing API keys, config parsing failures) appear in `logs/langgraph.log` or - `logs/gateway.log`. + (missing API keys, config parsing failures) appear in `logs/gateway.log`. diff --git a/frontend/src/content/en/harness/skills.mdx b/frontend/src/content/en/harness/skills.mdx index 09f8b0d43..78247c40b 100644 --- a/frontend/src/content/en/harness/skills.mdx +++ b/frontend/src/content/en/harness/skills.mdx @@ -68,7 +68,7 @@ DeerFlow ships with the following public skills: ### Discovery and loading -`load_skills()` in `skills/loader.py` scans both `public/` and `custom/` directories under the configured skills path. It re-reads `ExtensionsConfig.from_file()` on every call, which means enabling or disabling a skill through the Gateway API takes effect immediately in the running LangGraph server without a restart. +`load_skills()` in `skills/loader.py` scans both `public/` and `custom/` directories under the configured skills path. It re-reads `ExtensionsConfig.from_file()` on every call, which means enabling or disabling a skill through the Gateway API takes effect immediately in the running agent runtime without a restart. ### Parsing diff --git a/frontend/src/content/zh/application/configuration.mdx b/frontend/src/content/zh/application/configuration.mdx index 639eeaec5..0094323e7 100644 --- a/frontend/src/content/zh/application/configuration.mdx +++ b/frontend/src/content/zh/application/configuration.mdx @@ -215,7 +215,6 @@ BETTER_AUTH_SECRET=local-dev-secret-at-least-32-chars | `DEER_FLOW_CONFIG_PATH` | 自动发现 | `config.yaml` 的绝对路径 | | `LOG_LEVEL` | `info` | 日志详细程度(`debug`/`info`/`warning`/`error`) | | `DEER_FLOW_ROOT` | 仓库根目录 | 用于 Docker 中的技能和线程挂载 | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | nginx 代理的 LangGraph 地址 | diff --git a/frontend/src/content/zh/application/deployment-guide.mdx b/frontend/src/content/zh/application/deployment-guide.mdx index 59eceece2..635120337 100644 --- a/frontend/src/content/zh/application/deployment-guide.mdx +++ b/frontend/src/content/zh/application/deployment-guide.mdx @@ -23,8 +23,7 @@ make dev | 服务 | 端口 | 描述 | | ----------- | ---- | ----------------------- | -| LangGraph | 2024 | DeerFlow Harness 运行时 | -| Gateway API | 8001 | FastAPI 后端 | +| Gateway API | 8001 | FastAPI 后端 + 嵌入式 Agent 运行时 | | 前端 | 3000 | Next.js 界面 | | nginx | 2026 | 统一反向代理 | @@ -36,13 +35,12 @@ make dev make stop ``` -停止所有四个服务。即使某个服务没有运行也可以安全执行。 +停止所有服务。即使某个服务没有运行也可以安全执行。 ``` -logs/langgraph.log # Agent 运行时日志 -logs/gateway.log # API Gateway 日志 +logs/gateway.log # API Gateway 和 Agent 运行时日志 logs/frontend.log # Next.js 开发服务器日志 logs/nginx.log # nginx 访问/错误日志 ``` @@ -50,7 +48,7 @@ logs/nginx.log # nginx 访问/错误日志 实时追踪日志: ```bash -tail -f logs/langgraph.log +tail -f logs/gateway.log ``` @@ -96,7 +94,7 @@ BETTER_AUTH_SECRET=your-secret-here-min-32-chars ### 数据持久化 -线程数据存储在 `backend/.deer-flow/threads/`。在 Docker 部署中,此目录被绑定挂载到 langgraph 容器中。 +线程数据存储在 `backend/.deer-flow/threads/`。在 Docker 部署中,此目录会绑定挂载到 gateway 容器中。 为避免容器重建时数据丢失: @@ -156,14 +154,7 @@ SKILLS_PVC_NAME=deer-flow-skills-pvc ### nginx 配置 -nginx 路由所有流量,控制路由的关键环境变量: - -| 变量 | 默认值 | 描述 | -| -------------------- | ---------------- | ----------------------------- | -| `LANGGRAPH_UPSTREAM` | `langgraph:2024` | LangGraph 服务地址 | -| `LANGGRAPH_REWRITE` | `/` | LangGraph 路由的 URL 重写前缀 | - -这些在 Docker Compose 环境中设置,并在容器启动时由 `envsubst` 处理。 +nginx 将流量路由到前端或 Gateway。`/api/langgraph/*` 会被重写到 Gateway 的 LangGraph-compatible `/api/*` 路由,因此不需要单独的 LangGraph upstream。 ### 认证配置 @@ -181,8 +172,7 @@ openssl rand -base64 32 | 服务 | 最低配置 | 推荐配置 | | ------------------------- | ---------------- | ---------------- | -| LangGraph(Agent 运行时) | 2 vCPU、4 GB RAM | 4 vCPU、8 GB RAM | -| Gateway | 0.5 vCPU、512 MB | 1 vCPU、1 GB | +| Gateway + Agent 运行时 | 2 vCPU、4 GB RAM | 4 vCPU、8 GB RAM | | 前端 | 0.5 vCPU、512 MB | 1 vCPU、1 GB | | 沙箱容器(每会话) | 1 vCPU、1 GB | 2 vCPU、2 GB | @@ -194,9 +184,6 @@ openssl rand -base64 32 # 检查 Gateway 健康状态 curl http://localhost:8001/health -# 检查 LangGraph 健康状态 -curl http://localhost:2024/ok - # 通过 nginx 列出配置的模型(验证完整代理链) curl http://localhost:2026/api/models ``` diff --git a/frontend/src/content/zh/application/index.mdx b/frontend/src/content/zh/application/index.mdx index 81e7113e2..c12959b42 100644 --- a/frontend/src/content/zh/application/index.mdx +++ b/frontend/src/content/zh/application/index.mdx @@ -25,11 +25,11 @@ DeerFlow 应用是 DeerFlow 生产体验的参考实现。它将 Harness 运行 | **流式响应** | 实时 token 流式传输,带思考步骤和工具调用可见性 | | **产出物查看器** | Agent 生成文件和输出的浏览器内预览和下载 | | **扩展界面** | 无需编辑配置文件即可启用/禁用 MCP 服务器和技能 | -| **Gateway API** | 桥接前端和 LangGraph 运行时的基于 FastAPI 的 REST API | +| **Gateway API** | 基于 FastAPI 的 REST API,并内置 LangGraph-compatible Agent 运行时 | ## 架构 -DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理提供: +DeerFlow 应用通过单个 nginx 反向代理提供: ``` ┌──────────────────┐ @@ -42,19 +42,11 @@ DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理 │ 前端 :3000 │ │ Gateway API :8001 │ │ (Next.js) │ │ (FastAPI) │ └──────────────────┘ └──────────────────────┘ - │ - ┌─────────┘ - ▼ - ┌──────────────────────┐ - │ LangGraph :2024 │ - │ (DeerFlow Harness) │ - └──────────────────────┘ ``` -- **nginx**:路由请求——`/api/*` 到 Gateway,LangGraph 流式端点到 LangGraph,其余到前端。 -- **前端**(Next.js + React):浏览器界面,与 Gateway 和 LangGraph 通信。 -- **Gateway**(FastAPI):处理 API 操作——模型列表、Agent CRUD、记忆、扩展管理、文件上传。 -- **LangGraph**:DeerFlow Harness 运行时,管理线程状态、Agent 执行和流式传输。 +- **nginx**:路由请求——`/api/*` 和 `/api/langgraph/*` 到 Gateway,其余到前端。 +- **前端**(Next.js + React):浏览器界面,与 Gateway 通信。 +- **Gateway**(FastAPI):处理 API 操作,并通过内置 LangGraph-compatible 运行时管理线程状态、Agent 执行和流式传输。 ## 技术栈 @@ -64,7 +56,7 @@ DeerFlow 应用以四个服务的形式运行,通过单个 nginx 反向代理 | Gateway | FastAPI、Python 3.12、uvicorn | | Agent 运行时 | LangGraph、LangChain、DeerFlow Harness | | 反向代理 | nginx | -| 状态持久化 | LangGraph Server(默认)+ 可选 SQLite/PostgreSQL 检查点 | +| 状态持久化 | Gateway 运行时 + 可选 SQLite/PostgreSQL 检查点 | diff --git a/frontend/src/content/zh/application/operations-and-troubleshooting.mdx b/frontend/src/content/zh/application/operations-and-troubleshooting.mdx index c047bbd5c..8dc4c6551 100644 --- a/frontend/src/content/zh/application/operations-and-troubleshooting.mdx +++ b/frontend/src/content/zh/application/operations-and-troubleshooting.mdx @@ -15,16 +15,14 @@ DeerFlow 应用在 `logs/` 目录中写入每个服务的日志: | 文件 | 内容 | | -------------------- | -------------------------------------- | -| `logs/langgraph.log` | Agent 运行时、工具调用、LangGraph 错误 | -| `logs/gateway.log` | API 请求/响应、Gateway 错误 | +| `logs/gateway.log` | API 请求/响应、Agent 运行时和 Gateway 错误 | | `logs/frontend.log` | Next.js 服务器日志 | | `logs/nginx.log` | 代理访问和错误日志 | **实时追踪日志**: ```bash -tail -f logs/langgraph.log # 查看 Agent 活动 -tail -f logs/gateway.log # 查看 API 请求 +tail -f logs/gateway.log # 查看 API 请求和 Agent 活动 ``` **调整日志级别**: @@ -42,9 +40,6 @@ DeerFlow 暴露健康检查端点: # Gateway 健康状态 curl http://localhost:8001/health -# LangGraph 健康状态 -curl http://localhost:2024/ok - # 通过 nginx 完整代理链验证 curl http://localhost:2026/api/models ``` @@ -68,8 +63,8 @@ make config-upgrade **诊断**: ```bash -# 检查 LangGraph 日志中的模型错误 -grep -i "error\|apikey\|unauthorized" logs/langgraph.log | tail -20 +# 检查 Gateway 日志中的模型错误 +grep -i "error\|apikey\|unauthorized" logs/gateway.log | tail -20 ``` **解决**: @@ -118,13 +113,13 @@ SKIP_ENV_VALIDATION=1 pnpm build ### MCP 服务器连接失败 -**症状**:MCP 工具未出现,`logs/langgraph.log` 中有超时错误。 +**症状**:MCP 工具未出现,`logs/gateway.log` 中有超时错误。 **诊断**: ```bash # 检查 MCP 相关错误 -grep -i "mcp\|timeout" logs/langgraph.log | tail -20 +grep -i "mcp\|timeout" logs/gateway.log | tail -20 ``` **解决**: diff --git a/frontend/src/content/zh/application/quick-start.mdx b/frontend/src/content/zh/application/quick-start.mdx index 5ccf117ad..b5ab052fc 100644 --- a/frontend/src/content/zh/application/quick-start.mdx +++ b/frontend/src/content/zh/application/quick-start.mdx @@ -1,6 +1,6 @@ --- title: 快速上手 -description: 本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。所有四个服务(LangGraph、Gateway、前端、nginx)一起启动,通过单个 URL 访问。 +description: 本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。Gateway、前端和 nginx 会一起启动,通过单个 URL 访问。 --- import { Callout, Cards, Steps } from "nextra/components"; @@ -12,7 +12,7 @@ import { Callout, Cards, Steps } from "nextra/components"; 3.12+、Node.js 22+ 的机器,以及至少一个 LLM API Key。 -本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。所有四个服务(LangGraph、Gateway、前端、nginx)一起启动,通过单个 URL 访问。 +本指南引导你使用 `make dev` 工作流在本地机器上启动 DeerFlow 应用。Gateway、前端和 nginx 会一起启动,通过单个 URL 访问。 ## 前置条件 @@ -88,8 +88,7 @@ make dev 这会启动: -- LangGraph 服务,端口 `2024` -- Gateway API,端口 `8001` +- Gateway API 和嵌入式 Agent 运行时,端口 `8001` - 前端,端口 `3000` - nginx 反向代理,端口 `2026` @@ -113,15 +112,13 @@ make stop | 服务 | 日志文件 | | --------- | -------------------- | -| LangGraph | `logs/langgraph.log` | | Gateway | `logs/gateway.log` | | 前端 | `logs/frontend.log` | | nginx | `logs/nginx.log` | 如果有问题,先检查日志文件。大多数启动错误(缺失 API - Key、配置解析失败)会出现在 logs/langgraph.log 或{" "} - logs/gateway.log 中。 + Key、配置解析失败)会出现在 logs/gateway.log 中。 diff --git a/skills/public/claude-to-deerflow/SKILL.md b/skills/public/claude-to-deerflow/SKILL.md index d191f5c75..969a292c1 100644 --- a/skills/public/claude-to-deerflow/SKILL.md +++ b/skills/public/claude-to-deerflow/SKILL.md @@ -14,8 +14,8 @@ DeerFlow exposes two API surfaces behind an Nginx reverse proxy: | Service | Direct Port | Via Proxy | Purpose | |----------------|-------------|----------------------------------|----------------------------------| -| Gateway API | 8001 | `$DEERFLOW_GATEWAY_URL` | REST endpoints (models, skills, memory, uploads) | -| LangGraph API | 2024 | `$DEERFLOW_LANGGRAPH_URL` | Agent threads, runs, streaming | +| Gateway API | 8001 | `$DEERFLOW_GATEWAY_URL` | REST endpoints and embedded agent runtime | +| LangGraph-compatible API | 8001 | `$DEERFLOW_LANGGRAPH_URL` | Agent threads, runs, streaming | ## Environment Variables From f734e14d8b5004a6e5088499f27138f25e917653 Mon Sep 17 00:00:00 2001 From: greatmengqi Date: Tue, 12 May 2026 23:07:11 +0800 Subject: [PATCH 14/86] docs: document auth design and user isolation (#2913) * docs: document auth design and user isolation * docs: align auth docs with current storage and reset behavior --------- Co-authored-by: greatmengqi --- backend/app/gateway/app.py | 4 +- backend/app/gateway/auth/models.py | 2 +- backend/app/gateway/routers/auth.py | 2 +- backend/docs/API.md | 26 ++- backend/docs/AUTH_DESIGN.md | 331 +++++++++++++++++++++++++++ backend/docs/AUTH_TEST_DOCKER_GAP.md | 12 +- backend/docs/AUTH_TEST_PLAN.md | 254 +++++++++++--------- backend/docs/AUTH_UPGRADE.md | 59 +++-- backend/docs/README.md | 2 + 9 files changed, 547 insertions(+), 145 deletions(-) create mode 100644 backend/docs/AUTH_DESIGN.md diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 8848f473e..2c13f571c 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -62,7 +62,7 @@ async def _ensure_admin_user(app: FastAPI) -> None: Subsequent boots (admin already exists): - Runs the one-time "no-auth → with-auth" orphan thread migration for - existing LangGraph thread metadata that has no owner_id. + existing LangGraph thread metadata that has no user_id. No SQL persistence migration is needed: the four user_id columns (threads_meta, runs, run_events, feedback) only come into existence @@ -177,7 +177,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: async with langgraph_runtime(app): logger.info("LangGraph runtime initialised") - # Ensure admin user exists (auto-create on first boot) + # Check admin bootstrap state and migrate orphan threads after admin exists. # Must run AFTER langgraph_runtime so app.state.store is available for thread migration await _ensure_admin_user(app) diff --git a/backend/app/gateway/auth/models.py b/backend/app/gateway/auth/models.py index d8f9b954a..25c6476fe 100644 --- a/backend/app/gateway/auth/models.py +++ b/backend/app/gateway/auth/models.py @@ -28,7 +28,7 @@ class User(BaseModel): oauth_id: str | None = Field(None, description="User ID from OAuth provider") # Auth lifecycle - needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes") + needs_setup: bool = Field(default=False, description="True when a reset account must complete setup") token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs") diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py index 3a41e13eb..6192456fb 100644 --- a/backend/app/gateway/routers/auth.py +++ b/backend/app/gateway/routers/auth.py @@ -305,7 +305,7 @@ async def login_local( async def register(request: Request, response: Response, body: RegisterRequest): """Register a new user account (always 'user' role). - Admin is auto-created on first boot. This endpoint creates regular users. + The first admin is created explicitly through /initialize. This endpoint creates regular users. Auto-login by setting the session cookie. """ try: diff --git a/backend/docs/API.md b/backend/docs/API.md index d0b06ef0b..762a135c4 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -535,14 +535,28 @@ All APIs return errors in a consistent format: ## Authentication -Currently, DeerFlow does not implement authentication. All APIs are accessible without credentials. +DeerFlow enforces authentication for all non-public HTTP routes. Public routes are limited to health/docs metadata and these public auth endpoints: -Note: This is about DeerFlow API authentication. MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers. +- `POST /api/v1/auth/initialize` creates the first admin account when no admin exists. +- `POST /api/v1/auth/login/local` logs in with email/password and sets an HttpOnly `access_token` cookie. +- `POST /api/v1/auth/register` creates a regular `user` account and sets the session cookie. +- `POST /api/v1/auth/logout` clears the session cookie. +- `GET /api/v1/auth/setup-status` reports whether the first admin still needs to be created. -For production deployments, it is recommended to: -1. Use Nginx for basic auth or OAuth integration -2. Deploy behind a VPN or private network -3. Implement custom authentication middleware +The authenticated auth endpoints are: + +- `GET /api/v1/auth/me` returns the current user. +- `POST /api/v1/auth/change-password` changes password, optionally changes email during setup, increments `token_version`, and reissues the cookie. + +Protected state-changing requests also require the CSRF double-submit token: send the `csrf_token` cookie value as the `X-CSRF-Token` header. Login/register/initialize/logout are bootstrap auth endpoints: they are exempt from the double-submit token but still reject hostile browser `Origin` headers. + +User isolation is enforced from the authenticated user context: + +- Thread metadata is scoped by `threads_meta.user_id`; search/read/write/delete APIs only expose the current user's threads. +- Thread files live under `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/` and are exposed inside the sandbox as `/mnt/user-data/`. +- Memory and custom agents are stored under `{base_dir}/users/{user_id}/...`. + +Note: MCP outbound connections can still use OAuth for configured HTTP/SSE MCP servers; that is separate from DeerFlow API authentication. --- diff --git a/backend/docs/AUTH_DESIGN.md b/backend/docs/AUTH_DESIGN.md new file mode 100644 index 000000000..9a740871d --- /dev/null +++ b/backend/docs/AUTH_DESIGN.md @@ -0,0 +1,331 @@ +# 用户认证与隔离设计 + +本文档描述 DeerFlow 当前内置认证模块的设计,而不是历史 RFC。它覆盖浏览器登录、API 认证、CSRF、用户隔离、首次初始化、密码重置、内部调用和升级迁移。 + +## 设计目标 + +认证模块的核心目标是把 DeerFlow 从“本地单用户工具”提升为“可多用户部署的 agent runtime”,并让用户身份贯穿 HTTP API、LangGraph-compatible runtime、文件系统、memory、自定义 agent 和反馈数据。 + +设计约束: + +- 默认强制认证:除健康检查、文档和 auth bootstrap 端点外,HTTP 路由都必须有有效 session。 +- 服务端持有所有权:客户端 metadata 不能声明 `user_id` 或 `owner_id`。 +- 隔离默认开启:repository(仓储)、文件路径、memory、agent 配置默认按当前用户解析。 +- 旧数据可升级:无认证版本留下的 thread 可以在 admin 存在后迁移到 admin。 +- 密码不进日志:首次初始化由操作者设置密码;`reset_admin` 只写 0600 凭据文件。 + +非目标: + +- 当前 OAuth 端点只是占位,尚未实现第三方登录。 +- 当前用户角色只有 `admin` 和 `user`,尚未实现细粒度 RBAC。 +- 当前登录限速是进程内字典,多 worker 下不是全局精确限速。 + +## 核心模型 + +```mermaid +graph TB + classDef actor fill:#D8CFC4,stroke:#6E6259,color:#2F2A26; + classDef api fill:#C9D7D2,stroke:#5D706A,color:#21302C; + classDef state fill:#D7D3E8,stroke:#6B6680,color:#29263A; + classDef data fill:#E5D2C4,stroke:#806A5B,color:#30251E; + + Browser["Browser — access_token cookie and csrf_token cookie"]:::actor + AuthMiddleware["AuthMiddleware — strict session gate"]:::api + CSRFMiddleware["CSRFMiddleware — double-submit token and Origin check"]:::api + AuthRoutes["Auth routes — initialize login register logout me change-password"]:::api + UserContext["Current user ContextVar — request-scoped identity"]:::state + Repositories["Repositories — AUTO resolves user_id from context"]:::state + Files["Filesystem — users/{user_id}/threads/{thread_id}/user-data"]:::data + Memory["Memory and agents — users/{user_id}/memory.json and agents"]:::data + + Browser --> AuthMiddleware + Browser --> CSRFMiddleware + AuthMiddleware --> AuthRoutes + AuthMiddleware --> UserContext + UserContext --> Repositories + UserContext --> Files + UserContext --> Memory +``` + +### 用户表 + +用户记录定义在 `app.gateway.auth.models.User`,持久化到 `users` 表。关键字段: + +| 字段 | 语义 | +|---|---| +| `id` | 用户主键,JWT `sub` 使用该值 | +| `email` | 唯一登录名 | +| `password_hash` | bcrypt hash,OAuth 用户可为空 | +| `system_role` | `admin` 或 `user` | +| `needs_setup` | reset 后要求用户完成邮箱 / 密码设置 | +| `token_version` | 改密码或 reset 时递增,用于废弃旧 JWT | + +### 运行时身份 + +认证成功后,`AuthMiddleware` 把用户同时写入: + +- `request.state.user` +- `request.state.auth` +- `deerflow.runtime.user_context` 的 `ContextVar` + +`ContextVar` 是这里的核心边界。上层 Gateway 负责写入身份,下层 persistence / file path 只读取结构化的当前用户,不反向依赖 `app.gateway.auth` 具体类型。 + +可以把 repository 调用的用户参数理解成一个三态 ADT: + +```scala +enum UserScope: + case AutoFromContext + case Explicit(userId: String) + case BypassForMigration +``` + +对应 Python 实现是 `AUTO | str | None`: + +- `AUTO`:从 `ContextVar` 解析当前用户;没有上下文则抛错。 +- `str`:显式指定用户,主要用于测试或管理脚本。 +- `None`:跳过用户过滤,只允许迁移脚本或 admin CLI 使用。 + +## 登录与初始化流程 + +### 首次初始化 + +首次启动时,如果没有 admin,服务不会自动创建账号,只记录日志提示访问 `/setup`。 + +流程: + +1. 用户访问 `/setup`。 +2. 前端调用 `GET /api/v1/auth/setup-status`。 +3. 如果返回 `{"needs_setup": true}`,前端展示创建 admin 表单。 +4. 表单提交 `POST /api/v1/auth/initialize`。 +5. 服务端确认当前没有 admin,创建 `system_role="admin"`、`needs_setup=false` 的用户。 +6. 服务端设置 `access_token` HttpOnly cookie,用户进入 workspace。 + +`/api/v1/auth/initialize` 只在没有 admin 时可用。并发初始化由数据库唯一约束兜底,失败方返回 409。 + +### 普通登录 + +`POST /api/v1/auth/login/local` 使用 `OAuth2PasswordRequestForm`: + +- `username` 是邮箱。 +- `password` 是密码。 +- 成功后签发 JWT,放入 `access_token` HttpOnly cookie。 +- 响应体只返回 `expires_in` 和 `needs_setup`,不返回 token。 + +登录失败会按客户端 IP 计数。IP 解析只在 TCP peer 属于 `AUTH_TRUSTED_PROXIES` 时信任 `X-Real-IP`,不使用 `X-Forwarded-For`。 + +### 注册 + +`POST /api/v1/auth/register` 创建普通 `user`,并自动登录。 + +当前实现允许在没有 admin 时注册普通用户,但 `setup-status` 仍会返回 `needs_setup=true`,因为 admin 仍不存在。这是当前产品策略边界:如果后续要求“必须先初始化 admin 才能注册普通用户”,需要在 `/register` 增加 admin-exists gate。 + +### 改密码与 reset setup + +`POST /api/v1/auth/change-password` 需要当前密码和新密码: + +- 校验当前密码。 +- 更新 bcrypt hash。 +- `token_version += 1`,使旧 JWT 立即失效。 +- 重新签发 cookie。 +- 如果 `needs_setup=true` 且传了 `new_email`,则更新邮箱并清除 `needs_setup`。 + +`python -m app.gateway.auth.reset_admin` 会: + +- 找到 admin 或指定邮箱用户。 +- 生成随机密码。 +- 更新密码 hash。 +- `token_version += 1`。 +- 设置 `needs_setup=true`。 +- 写入 `.deer-flow/admin_initial_credentials.txt`,权限 `0600`。 + +命令行只输出凭据文件路径,不输出明文密码。 + +## HTTP 认证边界 + +`AuthMiddleware` 是 fail-closed(默认拒绝)的全局认证门。 + +公开路径: + +- `/health` +- `/docs` +- `/redoc` +- `/openapi.json` +- `/api/v1/auth/login/local` +- `/api/v1/auth/register` +- `/api/v1/auth/logout` +- `/api/v1/auth/setup-status` +- `/api/v1/auth/initialize` + +其余路径都要求有效 `access_token` cookie。存在 cookie 但 JWT 无效、过期、用户不存在或 `token_version` 不匹配时,直接返回 401,而不是让请求穿透到业务路由。 + +路由级别的 owner check 由 `require_permission(..., owner_check=True)` 完成: + +- 读类请求允许旧的未追踪 legacy thread 兼容读取。 +- 写 / 删除类请求使用 `require_existing=True`,要求 thread row 存在且属于当前用户,避免删除后缺 row 导致其他用户误通过。 + +## CSRF 设计 + +DeerFlow 使用 Double Submit Cookie: + +- 服务端设置 `csrf_token` cookie。 +- 前端 state-changing 请求发送同值 `X-CSRF-Token` header。 +- 服务端用 `secrets.compare_digest` 比较 cookie/header。 + +需要 CSRF 的方法: + +- `POST` +- `PUT` +- `DELETE` +- `PATCH` + +auth bootstrap 端点(login/register/initialize/logout)不要求 double-submit token,因为首次调用时浏览器还没有 token;但这些端点会校验 browser `Origin`,拒绝 hostile Origin,避免 login CSRF / session fixation。 + +## 用户隔离 + +### Thread metadata + +Thread metadata 存在 `threads_meta`,关键隔离字段是 `user_id`。 + +创建 thread 时: + +- 客户端传入的 `metadata.user_id` 和 `metadata.owner_id` 会被剥离。 +- `ThreadMetaRepository.create(..., user_id=AUTO)` 从 `ContextVar` 解析真实用户。 +- `/api/threads/search` 默认只返回当前用户的 thread。 + +读取 / 修改 / 删除时: + +- `get()` 默认按当前用户过滤。 +- `check_access()` 用于路由 owner check。 +- 对其他用户的 thread 返回 404,避免泄露资源存在性。 + +### 文件系统 + +当前线程文件布局: + +```text +{base_dir}/users/{user_id}/threads/{thread_id}/user-data/ +├── workspace/ +├── uploads/ +└── outputs/ +``` + +agent 在 sandbox 内看到统一虚拟路径: + +```text +/mnt/user-data/workspace +/mnt/user-data/uploads +/mnt/user-data/outputs +``` + +`ThreadDataMiddleware` 使用 `get_effective_user_id()` 解析当前用户并生成线程路径。没有认证上下文时会落到 `default` 用户桶,主要用于内部调用、嵌入式 client 或无 HTTP 的本地执行路径。 + +### Memory + +默认 memory 存储: + +```text +{base_dir}/users/{user_id}/memory.json +{base_dir}/users/{user_id}/agents/{agent_name}/memory.json +``` + +有用户上下文时,空或相对 `memory.storage_path` 都使用上述 per-user 默认路径;只有绝对 `memory.storage_path` 会视为显式 opt-out(退出) per-user isolation,所有用户共享该路径。无用户上下文的 legacy 路径仍会把相对 `storage_path` 解析到 `Paths.base_dir` 下。 + +### 自定义 agent + +用户自定义 agent 写入: + +```text +{base_dir}/users/{user_id}/agents/{agent_name}/ +├── config.yaml +├── SOUL.md +└── memory.json +``` + +旧布局 `{base_dir}/agents/{agent_name}/` 只作为只读兼容回退。更新或删除旧共享 agent 会要求先运行迁移脚本。 + +## 内部调用与 IM 渠道 + +IM channel worker 不是浏览器用户,不持有浏览器 cookie。它们通过 Gateway 内部认证: + +- 请求带 `X-DeerFlow-Internal-Token`。 +- 同时带匹配的 CSRF cookie/header。 +- 服务端识别为内部用户,`id="default"`、`system_role="internal"`。 + +这意味着 channel 产生的数据默认进入 `default` 用户桶。这个选择适合“平台级 bot 身份”,但不是“每个 IM 用户单独隔离”。如果后续要做到外部 IM 用户隔离,需要把外部 platform user 映射到 DeerFlow user,并让 channel manager 设置对应的 scoped identity。 + +## LangGraph-compatible 认证 + +Gateway 内嵌 runtime 路径由 `AuthMiddleware` 和 `CSRFMiddleware` 保护。 + +仓库仍保留 `app.gateway.langgraph_auth`,用于 LangGraph Server 直连模式: + +- `@auth.authenticate` 校验 JWT cookie、CSRF、用户存在性和 `token_version`。 +- `@auth.on` 在写入 metadata 时注入 `user_id`,并在读路径返回 `{"user_id": current_user}` 过滤条件。 + +这保证 Gateway 路由和 LangGraph-compatible 直连模式使用同一 JWT 语义。 + +## 升级与迁移 + +从无认证版本升级时,可能存在没有 `user_id` 的历史 thread。 + +当前策略: + +1. 首次启动如果没有 admin,只提示访问 `/setup`,不迁移。 +2. 操作者创建 admin。 +3. 后续启动时,`_ensure_admin_user()` 找到 admin,并把 LangGraph store 中缺少 `metadata.user_id` 的 thread 迁移到 admin。 + +文件系统旧布局迁移由脚本处理: + +```bash +cd backend +PYTHONPATH=. python scripts/migrate_user_isolation.py --dry-run +PYTHONPATH=. python scripts/migrate_user_isolation.py --user-id +``` + +迁移脚本覆盖 legacy `memory.json`、`threads/` 和 `agents/` 到 per-user layout。 + +## 安全不变量 + +必须长期保持的不变量: + +- JWT 只在 HttpOnly cookie 中传输,不出现在响应 JSON。 +- 任何非 public HTTP 路由都不能只靠“cookie 存在”放行,必须严格验证 JWT。 +- `token_version` 不匹配必须拒绝,保证改密码 / reset 后旧 session 失效。 +- 客户端 metadata 中的 `user_id` / `owner_id` 必须剥离。 +- repository 默认 `AUTO` 必须从当前用户上下文解析,不能静默退化成全局查询。 +- 只有迁移脚本和 admin CLI 可以显式传 `user_id=None` 绕过隔离。 +- 本地文件路径必须通过 `Paths` 和 sandbox path validation 解析,不能拼接未校验的用户输入。 +- 捕获认证、迁移、后台任务异常必须记录日志;不能空 catch。 + +## 已知边界 + +| 边界 | 当前行为 | 后续方向 | +|---|---|---| +| 无 admin 时注册普通用户 | 允许注册普通 `user` | 如产品要求先初始化 admin,给 `/register` 加 gate | +| 登录限速 | 进程内 dict,单 worker 精确,多 worker 近似 | Redis / DB-backed rate limiter | +| OAuth | 端点占位,未实现 | 接入 provider 并统一 `token_version` / role 语义 | +| IM 用户隔离 | channel 使用 `default` 内部用户 | 建立外部用户到 DeerFlow user 的映射 | +| 绝对 memory path | 显式共享 memory | UI / docs 明确提示 opt-out 风险 | + +## 相关文件 + +| 文件 | 职责 | +|---|---| +| `app/gateway/auth_middleware.py` | 全局认证门、JWT 严格验证、写入 user context | +| `app/gateway/csrf_middleware.py` | CSRF double-submit 和 auth Origin 校验 | +| `app/gateway/routers/auth.py` | initialize/login/register/logout/me/change-password | +| `app/gateway/auth/jwt.py` | JWT 创建与解析 | +| `app/gateway/auth/reset_admin.py` | 密码 reset CLI | +| `app/gateway/auth/credential_file.py` | 0600 凭据文件写入 | +| `app/gateway/authz.py` | 路由权限与 owner check | +| `deerflow/runtime/user_context.py` | 当前用户 ContextVar 与 `AUTO` sentinel | +| `deerflow/persistence/thread_meta/` | thread metadata owner filter | +| `deerflow/config/paths.py` | per-user filesystem layout | +| `deerflow/agents/middlewares/thread_data_middleware.py` | run 时解析用户线程目录 | +| `deerflow/agents/memory/storage.py` | per-user memory storage | +| `deerflow/config/agents_config.py` | per-user custom agents | +| `app/channels/manager.py` | IM channel 内部认证调用 | +| `scripts/migrate_user_isolation.py` | legacy 数据迁移到 per-user layout | +| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库,包含 users / threads_meta / runs / feedback 等表 | +| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory | +| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) | diff --git a/backend/docs/AUTH_TEST_DOCKER_GAP.md b/backend/docs/AUTH_TEST_DOCKER_GAP.md index adf4916a3..969aad92c 100644 --- a/backend/docs/AUTH_TEST_DOCKER_GAP.md +++ b/backend/docs/AUTH_TEST_DOCKER_GAP.md @@ -24,11 +24,11 @@ All other test plan sections were executed against either: | Case | Title | What it covers | Why not run | |---|---|---|---| -| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` | +| TC-DOCKER-01 | `deerflow.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` | | TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` | | TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container | -| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` | -| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume | +| TC-DOCKER-04 | IM channels use internal Gateway auth | Verify Feishu/Slack/Telegram dispatchers attach the process-local internal auth header plus CSRF cookie/header when calling Gateway-compatible LangGraph APIs | needs `docker logs` | +| TC-DOCKER-05 | Reset credentials surfacing | `reset_admin` writes a 0600 credential file in `DEER_FLOW_HOME` instead of logging plaintext. The file-based behavior is validated by non-Docker reset tests, so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume | | TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` | ## Coverage already provided by non-Docker tests @@ -41,8 +41,8 @@ the test cases that ran on sg_dev or local: | TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between | | TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` | | TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs | -| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP | -| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change | +| TC-DOCKER-04 (IM channels use internal auth) | Code-level: `app/channels/manager.py` creates the `langgraph_sdk` client with `create_internal_auth_headers()` plus CSRF cookie/header, so channel workers do not rely on browser cookies | +| TC-DOCKER-05 (credential surfacing) | `reset_admin` writes `.deer-flow/admin_initial_credentials.txt` with mode 0600 and logs only the path — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change | | TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change | ## Reproduction steps when Docker becomes available @@ -72,6 +72,6 @@ Then run TC-DOCKER-01..06 from the test plan as written. about *container packaging* details (bind mounts, multi-worker, log collection), not about whether the auth code paths work. - **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect - the post-simplify reality (credentials file → 0600 file, no log leak). + the current reset flow (`reset_admin` → 0600 credentials file, no log leak). The old "grep 'Password:' in docker logs" expectation would have failed silently and given a false sense of coverage. diff --git a/backend/docs/AUTH_TEST_PLAN.md b/backend/docs/AUTH_TEST_PLAN.md index 15b20494a..e5245d60b 100644 --- a/backend/docs/AUTH_TEST_PLAN.md +++ b/backend/docs/AUTH_TEST_PLAN.md @@ -19,7 +19,7 @@ ```bash # 清除已有数据 -rm -f backend/.deer-flow/users.db +rm -f backend/.deer-flow/data/deerflow.db # 选择模式启动 make dev # 标准模式 @@ -28,10 +28,11 @@ make dev-pro # Gateway 模式 ``` **验证点:** -- [ ] 控制台输出 admin 邮箱和随机密码 -- [ ] 密码格式为 `secrets.token_urlsafe(16)` 的 22 字符字符串 -- [ ] 邮箱为 `admin@deerflow.dev` -- [ ] 提示 `Change it after login: Settings -> Account` +- [ ] 控制台不输出 admin 邮箱或明文密码 +- [ ] 控制台提示 `First boot detected — no admin account exists.` +- [ ] 控制台提示访问 `/setup` 完成 admin 创建 +- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": true}` +- [ ] 前端访问 `/login` 会跳转 `/setup` ### 1.2 非首次启动 @@ -42,7 +43,8 @@ make dev **验证点:** - [ ] 控制台不输出密码 -- [ ] 如果 admin 仍 `needs_setup=True`,控制台有 warning 提示 +- [ ] `GET /api/v1/auth/setup-status` 返回 `{"needs_setup": false}` +- [ ] 已登录用户如果 `needs_setup=True`,访问 workspace 会被引导到 `/setup` 完成改邮箱 / 改密码流程 ### 1.3 环境变量配置 @@ -76,19 +78,22 @@ make dev curl -s $BASE/api/v1/auth/setup-status | jq . ``` -**预期:** 返回 `{"needs_setup": false}`(admin 在启动时已自动创建,`count_users() > 0`)。仅在启动完成前的极短窗口内可能返回 `true`。 +**预期:** +- 干净数据库且尚未初始化 admin:返回 `{"needs_setup": true}` +- 已存在 admin:返回 `{"needs_setup": false}` -#### TC-API-02: Admin 首次登录 +#### TC-API-02: 首次初始化 Admin ```bash -curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<控制台密码>" \ +curl -s -X POST $BASE/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ -c cookies.txt | jq . ``` **预期:** -- 状态码 200 -- Body: `{"expires_in": 604800, "needs_setup": true}` +- 状态码 201 +- Body: `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}` - `cookies.txt` 包含 `access_token`(HttpOnly)和 `csrf_token`(非 HttpOnly) #### TC-API-03: 获取当前用户 @@ -97,9 +102,9 @@ curl -s -X POST $BASE/api/v1/auth/login/local \ curl -s $BASE/api/v1/auth/me -b cookies.txt | jq . ``` -**预期:** `{"id": "...", "email": "admin@deerflow.dev", "system_role": "admin", "needs_setup": true}` +**预期:** `{"id": "...", "email": "admin@example.com", "system_role": "admin", "needs_setup": false}` -#### TC-API-04: Setup 流程(改邮箱 + 改密码) +#### TC-API-04: 改密码流程 ```bash CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}') @@ -107,13 +112,36 @@ curl -s -X POST $BASE/api/v1/auth/change-password \ -b cookies.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"current_password":"<控制台密码>","new_password":"NewPass123!","new_email":"admin@example.com"}' | jq . + -d '{"current_password":"AdminPass1!","new_password":"NewPass123!"}' | jq . ``` **预期:** - 状态码 200 - `{"message": "Password changed successfully"}` -- 再调 `/auth/me` 邮箱变为 `admin@example.com`,`needs_setup` 变为 `false` +- 再调 `/auth/me` 仍为 `admin@example.com`,`needs_setup` 仍为 `false` + +#### TC-API-04a: reset_admin 后的 Setup 流程(改邮箱 + 改密码) + +```bash +cd backend +python -m app.gateway.auth.reset_admin --email admin@example.com +# 从 .deer-flow/admin_initial_credentials.txt 读取 reset 后密码 + +curl -s -X POST $BASE/api/v1/auth/login/local \ + -d "username=admin@example.com&password=<凭据文件密码>" \ + -c cookies.txt | jq . + +CSRF=$(grep csrf_token cookies.txt | awk '{print $NF}') +curl -s -X POST $BASE/api/v1/auth/change-password \ + -b cookies.txt \ + -H "Content-Type: application/json" \ + -H "X-CSRF-Token: $CSRF" \ + -d '{"current_password":"<凭据文件密码>","new_password":"AdminPass2!","new_email":"admin2@example.com"}' | jq . +``` + +**预期:** +- 登录返回 `{"expires_in": 604800, "needs_setup": true}` +- `change-password` 后 `/auth/me` 邮箱变为 `admin2@example.com`,`needs_setup` 变为 `false` #### TC-API-05: 普通用户注册 @@ -493,7 +521,7 @@ curl -s -X POST $BASE/api/v1/auth/register \ ```bash # 检查数据库 -sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMIT 3;" +sqlite3 backend/.deer-flow/data/deerflow.db "SELECT email, password_hash FROM users LIMIT 3;" ``` **预期:** `password_hash` 以 `$2b$` 开头(bcrypt 格式) @@ -506,24 +534,25 @@ sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMI ### 4.1 首次登录流程 -#### TC-UI-01: 访问首页跳转登录 +#### TC-UI-01: 无 admin 时访问 workspace 跳转 setup 1. 打开 `http://localhost:2026/workspace` -2. **预期:** 自动跳转到 `/login` +2. **预期:** 自动跳转到 `/setup` -#### TC-UI-02: Login 页面 +#### TC-UI-02: Setup 页面创建 admin -1. 输入 admin 邮箱和控制台密码 -2. 点击 Login -3. **预期:** 跳转到 `/setup`(因为 `needs_setup=true`) - -#### TC-UI-03: Setup 页面 - -1. 输入新邮箱、控制台密码(current)、新密码、确认密码 -2. 点击 Complete Setup +1. 输入 admin 邮箱、密码、确认密码 +2. 点击 Create Admin Account 3. **预期:** 跳转到 `/workspace` 4. 刷新页面不跳回 `/setup` +#### TC-UI-03: 已初始化后 Login 页面 + +1. 退出登录后访问 `/login` +2. 输入 admin 邮箱和密码 +3. 点击 Login +4. **预期:** 跳转到 `/workspace` + #### TC-UI-04: Setup 密码不匹配 1. 新密码和确认密码不一致 @@ -602,7 +631,7 @@ sqlite3 backend/.deer-flow/users.db "SELECT email, password_hash FROM users LIMI #### TC-UI-15: reset_admin 后重新登录 1. 执行 `cd backend && python -m app.gateway.auth.reset_admin` -2. 使用新密码登录 +2. 从 `.deer-flow/admin_initial_credentials.txt` 读取新密码并登录 3. **预期:** 跳转到 `/setup` 页面(`needs_setup` 被重置为 true) 4. 旧 session 已失效 @@ -645,18 +674,28 @@ make install make dev ``` -#### TC-UPG-01: 首次启动创建 admin +#### TC-UPG-01: 首次启动等待 admin 初始化 **预期:** -- [ ] 控制台输出 admin 邮箱(`admin@deerflow.dev`)和随机密码 +- [ ] 控制台不输出 admin 邮箱或随机密码 +- [ ] 访问 `/setup` 可创建第一个 admin - [ ] 无报错,正常启动 #### TC-UPG-02: 旧 Thread 迁移到 admin ```bash +# 创建第一个 admin +curl -s -X POST http://localhost:2026/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ + -c cookies.txt + +# 重启一次:启动迁移只在已有 admin 的启动路径执行 +make stop && make dev + # 登录 admin curl -s -X POST http://localhost:2026/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<控制台密码>" \ + -d "username=admin@example.com&password=AdminPass1!" \ -c cookies.txt # 查看 thread 列表 @@ -670,8 +709,8 @@ curl -s -X POST http://localhost:2026/api/threads/search \ **预期:** - [ ] 返回的 thread 数量 ≥ 旧版创建的数量 -- [ ] 控制台日志有 `Migrated N orphaned thread(s) to admin` -- [ ] 每个 thread 的 `metadata.owner_id` 都已被设为 admin 的 ID +- [ ] 控制台日志有 `Migrated N orphan LangGraph thread(s) to admin` +- [ ] 旧 thread 只对 admin 可见 #### TC-UPG-03: 旧 Thread 内容完整 @@ -683,7 +722,7 @@ curl -s http://localhost:2026/api/threads/ \ **预期:** - [ ] `metadata.title` 保留原值(如 `old-thread-1`) -- [ ] `metadata.owner_id` 已填充 +- [ ] 响应不回显服务端保留的 `user_id` / `owner_id` #### TC-UPG-04: 新用户看不到旧 Thread @@ -706,18 +745,19 @@ curl -s -X POST http://localhost:2026/api/threads/search \ ### 5.3 数据库 Schema 兼容 -#### TC-UPG-05: 无 users.db 时自动创建 +#### TC-UPG-05: 无 deerflow.db 时创建 schema 但不创建默认用户 ```bash -ls -la backend/.deer-flow/users.db +ls -la backend/.deer-flow/data/deerflow.db +sqlite3 backend/.deer-flow/data/deerflow.db "SELECT COUNT(*) FROM users;" ``` -**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列 +**预期:** 文件存在,`sqlite3` 可查到 `users` 表含 `needs_setup`、`token_version` 列;未调用 `/initialize` 前用户数为 0 -#### TC-UPG-06: users.db WAL 模式 +#### TC-UPG-06: deerflow.db WAL 模式 ```bash -sqlite3 backend/.deer-flow/users.db "PRAGMA journal_mode;" +sqlite3 backend/.deer-flow/data/deerflow.db "PRAGMA journal_mode;" ``` **预期:** 返回 `wal` @@ -768,9 +808,9 @@ make dev ``` **预期:** -- [ ] 服务正常启动(忽略 `users.db`,无 auth 相关代码不报错) +- [ ] 服务正常启动(忽略 `deerflow.db`,无 auth 相关代码不报错) - [ ] 旧对话数据仍然可访问 -- [ ] `users.db` 文件残留但不影响运行 +- [ ] `deerflow.db` 文件残留但不影响运行 #### TC-UPG-12: 再次升级到 auth 分支 @@ -781,51 +821,47 @@ make dev ``` **预期:** -- [ ] 识别已有 `users.db`,不重新创建 admin -- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `users.db`) +- [ ] 识别已有 `deerflow.db`,不重新创建 admin +- [ ] 旧的 admin 账号仍可登录(如果回退期间未删 `deerflow.db`) -### 5.7 休眠 Admin(初始密码未使用/未更改) +### 5.7 Admin 初始化与 reset_admin -> 首次启动生成 admin + 随机密码,但运维未登录、未改密码。 -> 密码只在首次启动的控制台闪过一次,后续启动不再显示。 +> 首次启动不生成默认 admin,也不在日志输出密码。忘记密码时走 `reset_admin`,新密码写入 0600 凭据文件。 -#### TC-UPG-13: 重启后自动重置密码并打印 +#### TC-UPG-13: 未初始化 admin 时重启不创建默认账号 ```bash -# 首次启动,记录密码 -rm -f backend/.deer-flow/users.db +rm -f backend/.deer-flow/data/deerflow.db make dev -# 控制台输出密码 P0,不登录 make stop -# 隔了几天,再次启动 make dev -# 控制台输出新密码 P1 +curl -s $BASE/api/v1/auth/setup-status | jq . ``` **预期:** -- [ ] 控制台输出 `Admin account setup incomplete — password reset` -- [ ] 输出新密码 P1(P0 已失效) -- [ ] 用 P1 可以登录,P0 不可以 -- [ ] 登录后 `needs_setup=true`,跳转 `/setup` -- [ ] `token_version` 递增(旧 session 如有也失效) +- [ ] 控制台不输出密码 +- [ ] `setup-status` 仍为 `{"needs_setup": true}` +- [ ] 访问 `/setup` 仍可创建第一个 admin -#### TC-UPG-14: 密码丢失 — 无需 CLI,重启即可 +#### TC-UPG-14: 密码丢失 — reset_admin 写入凭据文件 ```bash -# 忘记了控制台密码 → 直接重启服务 -make stop && make dev -# 控制台自动输出新密码 +python -m app.gateway.auth.reset_admin --email admin@example.com +ls -la backend/.deer-flow/admin_initial_credentials.txt +cat backend/.deer-flow/admin_initial_credentials.txt ``` **预期:** -- [ ] 无需 `reset_admin`,重启服务即可拿到新密码 -- [ ] `reset_admin` CLI 仍然可用作手动备选方案 +- [ ] 命令行只输出凭据文件路径,不输出明文密码 +- [ ] 凭据文件权限为 `0600` +- [ ] 凭据文件包含 email + password 行 +- [ ] 该用户下次登录返回 `needs_setup=true` -#### TC-UPG-15: 休眠 admin 期间普通用户注册 +#### TC-UPG-15: 未初始化 admin 期间普通用户注册策略边界 ```bash -# admin 存在但从未登录,普通用户先注册 +# admin 尚不存在,普通用户尝试注册 curl -s -X POST $BASE/api/v1/auth/register \ -H "Content-Type: application/json" \ -d '{"email":"earlybird@example.com","password":"EarlyPass1!"}' \ @@ -833,11 +869,11 @@ curl -s -X POST $BASE/api/v1/auth/register \ ``` **预期:** -- [ ] 注册成功(201),角色为 `user` -- [ ] 无法提权为 admin -- [ ] 普通用户的数据与 admin 隔离 +- [ ] 当前代码允许注册普通用户并自动登录(201,角色为 `user`) +- [ ] 但 `setup-status` 仍为 `{"needs_setup": true}`,因为 admin 仍不存在 +- [ ] 这是一个产品策略边界:若要求“必须先有 admin”,需要在 `/register` 增加 admin-exists gate -#### TC-UPG-16: 休眠 admin 不影响后续操作 +#### TC-UPG-16: 普通用户数据与后续 admin 隔离 ```bash # 普通用户正常创建 thread、发消息 @@ -849,14 +885,13 @@ curl -s -X POST $BASE/api/threads \ -d '{"metadata":{}}' | jq .thread_id ``` -**预期:** 正常创建,不受休眠 admin 影响 +**预期:** 普通用户正常创建 thread;后续 admin 创建后,搜索不到该普通用户 thread -#### TC-UPG-17: 休眠 admin 最终完成 Setup +#### TC-UPG-17: reset_admin 后完成 Setup ```bash -# 运维终于登录 curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=" \ + -d "username=admin@example.com&password=<凭据文件密码>" \ -c admin.txt | jq .needs_setup # 预期: true @@ -866,7 +901,7 @@ curl -s -X POST $BASE/api/v1/auth/change-password \ -b admin.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"current_password":"<密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \ + -d '{"current_password":"<凭据文件密码>","new_password":"AdminFinal1!","new_email":"admin@real.com"}' \ -c admin.txt # 验证 @@ -876,7 +911,7 @@ curl -s $BASE/api/v1/auth/me -b admin.txt | jq '{email, needs_setup}' **预期:** - [ ] `email` 变为 `admin@real.com` - [ ] `needs_setup` 变为 `false` -- [ ] 后续重启控制台不再有 warning +- [ ] 后续登录使用新密码 #### TC-UPG-18: 长期未用后 JWT 密钥轮换 @@ -890,8 +925,8 @@ make stop && make dev **预期:** - [ ] 服务正常启动 -- [ ] 旧密码仍可登录(密码存在 DB,与 JWT 密钥无关) -- [ ] 旧的 JWT token 失效(密钥变了签名不匹配)— 但因为从未登录过也没有旧 token +- [ ] 账号密码仍可登录(密码存在 DB,与 JWT 密钥无关) +- [ ] 旧的 JWT token 失效(密钥变了签名不匹配) --- @@ -910,7 +945,7 @@ for i in 1 2 3; do done # 检查 admin 数量 -sqlite3 backend/.deer-flow/users.db \ +sqlite3 backend/.deer-flow/data/deerflow.db \ "SELECT COUNT(*) FROM users WHERE system_role='admin';" ``` @@ -1055,7 +1090,7 @@ curl -s -X POST $BASE/api/v1/auth/register \ wait # 检查用户数 -sqlite3 backend/.deer-flow/users.db \ +sqlite3 backend/.deer-flow/data/deerflow.db \ "SELECT COUNT(*) FROM users WHERE email='race@example.com';" ``` @@ -1165,13 +1200,16 @@ curl -s -w "%{http_code}" -X DELETE "$BASE/api/threads/$TID" \ ```bash cd backend python -m app.gateway.auth.reset_admin -# 记录密码 P1 +cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p1.txt +P1=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p1.txt) python -m app.gateway.auth.reset_admin -# 记录密码 P2 +cp .deer-flow/admin_initial_credentials.txt /tmp/deerflow-reset-p2.txt +P2=$(awk -F': ' '/^password:/ {print $2}' /tmp/deerflow-reset-p2.txt) ``` **预期:** +- [ ] `.deer-flow/admin_initial_credentials.txt` 每次都会被重写,文件权限为 `0600` - [ ] P1 ≠ P2(每次生成新随机密码) - [ ] P1 不可用,只有 P2 有效 - [ ] `token_version` 递增了 2 @@ -1324,7 +1362,8 @@ done ```bash GW=http://localhost:8001 -for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local /api/v1/auth/register; do +for path in /health /api/v1/auth/setup-status /api/v1/auth/login/local \ + /api/v1/auth/register /api/v1/auth/initialize /api/v1/auth/logout; do echo "$path: $(curl -s -w '%{http_code}' -o /dev/null $GW$path)" done # 预期: 200 或 405/422(方法不对但不是 401) @@ -1399,9 +1438,9 @@ done > > 前置条件: > - `.env` 中设置 `AUTH_JWT_SECRET`(否则每次容器重启 session 全部失效) -> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `users.db`) +> - `DEER_FLOW_HOME` 挂载到宿主机目录(持久化 `deerflow.db`) -#### TC-DOCKER-01: users.db 通过 volume 持久化 +#### TC-DOCKER-01: deerflow.db 通过 volume 持久化 ```bash # 启动容器 @@ -1416,13 +1455,13 @@ curl -s -X POST $BASE/api/v1/auth/register \ -H "Content-Type: application/json" \ -d '{"email":"docker-test@example.com","password":"DockerTest1!"}' -w "\nHTTP %{http_code}" -# 检查宿主机上的 users.db -ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db -sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/users.db \ +# 检查宿主机上的 deerflow.db +ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db +sqlite3 ${DEER_FLOW_HOME:-backend/.deer-flow}/data/deerflow.db \ "SELECT email FROM users WHERE email='docker-test@example.com';" ``` -**预期:** users.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。 +**预期:** deerflow.db 在宿主机 `DEER_FLOW_HOME` 目录中,查询可见刚注册的用户。 #### TC-DOCKER-02: 重启容器后 session 保持 @@ -1466,22 +1505,24 @@ done **已知限制:** In-process rate limiter 不跨 worker 共享。生产环境如需精确限速,需要 Redis 等外部存储。 -#### TC-DOCKER-04: IM 渠道不经过 auth +#### TC-DOCKER-04: IM 渠道使用内部认证 ```bash -# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 通信 -# 不走 nginx,不经过 AuthMiddleware +# IM 渠道(Feishu/Slack/Telegram)在 gateway 容器内部通过 LangGraph SDK 调 Gateway +# 请求携带 process-local internal auth header,并带匹配的 CSRF cookie/header # 验证方式:检查 gateway 日志中 channel manager 的请求不包含 auth 错误 docker logs deer-flow-gateway 2>&1 | grep -E "ChannelManager|channel" | head -10 ``` -**预期:** 无 auth 相关错误。渠道通过 `langgraph-sdk` 直连 LangGraph Server(`http://langgraph:2024`),不走 auth 层。 +**预期:** 无 auth 相关错误。渠道不依赖浏览器 cookie;服务端通过内部认证头把请求归入 `default` 用户桶。 -#### TC-DOCKER-05: admin 密码写入 0600 凭证文件(不再走日志) +#### TC-DOCKER-05: reset_admin 密码写入 0600 凭证文件(不再走日志) ```bash -# 凭证文件写在挂载到宿主机的 DEER_FLOW_HOME 下 +# 首次启动不会自动生成 admin 密码。先重置已有 admin,凭据文件写在挂载到宿主机的 DEER_FLOW_HOME 下。 +docker exec deer-flow-gateway python -m app.gateway.auth.reset_admin --email docker-test@example.com + ls -la ${DEER_FLOW_HOME:-backend/.deer-flow}/admin_initial_credentials.txt # 预期文件权限: -rw------- (0600) @@ -1512,14 +1553,15 @@ sleep 15 docker ps --filter name=deer-flow-langgraph --format '{{.Names}}' | wc -l # 预期: 0 -# auth 流程正常 +# auth 流程正常:未登录受保护接口返回 401 curl -s -w "%{http_code}" -o /dev/null $BASE/api/models # 预期: 401 -curl -s -X POST $BASE/api/v1/auth/login/local \ - -d "username=admin@deerflow.dev&password=<日志密码>" \ +curl -s -X POST $BASE/api/v1/auth/initialize \ + -H "Content-Type: application/json" \ + -d '{"email":"admin@example.com","password":"AdminPass1!"}' \ -c cookies.txt -w "\nHTTP %{http_code}" -# 预期: 200 +# 预期: 201 ``` ### 7.4 补充边界用例 @@ -1587,13 +1629,15 @@ curl -s -D - -X POST $BASE/api/v1/auth/login/local \ #### TC-EDGE-05: HTTP 无 max_age / HTTPS 有 max_age ```bash +GW=http://localhost:8001 + # HTTP -curl -s -D - -X POST $BASE/api/v1/auth/login/local \ +curl -s -D - -X POST $GW/api/v1/auth/login/local \ -d "username=admin@example.com&password=正确密码" 2>/dev/null \ | grep "access_token=" | grep -oi "max-age=[0-9]*" || echo "NO max-age (HTTP session cookie)" -# HTTPS -curl -s -D - -X POST $BASE/api/v1/auth/login/local \ +# HTTPS:直连 Gateway 才能用 X-Forwarded-Proto 模拟 HTTPS;nginx 会覆盖该 header +curl -s -D - -X POST $GW/api/v1/auth/login/local \ -H "X-Forwarded-Proto: https" \ -d "username=admin@example.com&password=正确密码" 2>/dev/null \ | grep "access_token=" | grep -oi "max-age=[0-9]*" @@ -1712,10 +1756,10 @@ curl -s -X POST $BASE/api/threads \ -b cookies.txt \ -H "Content-Type: application/json" \ -H "X-CSRF-Token: $CSRF" \ - -d '{"metadata":{"owner_id":"victim-user-id"}}' | jq .metadata.owner_id + -d '{"metadata":{"owner_id":"victim-user-id","user_id":"victim-user-id"}}' | jq .metadata ``` -**预期:** 返回的 `metadata.owner_id` 应为当前登录用户的 ID,不是请求中注入的 `victim-user-id`。服务端应覆盖客户端提供的 `user_id`。 +**预期:** 返回的 `metadata` 不包含 `owner_id` 或 `user_id`。真实所有权写入 `threads_meta.user_id`,不从客户端 metadata 接收,也不通过 metadata 回显。 #### 7.5.6 HTTP Method 探测 @@ -1796,6 +1840,6 @@ cd backend && PYTHONPATH=. uv run pytest \ # 核心接口冒烟 curl -s $BASE/health # 200 curl -s $BASE/api/models # 401 (无 cookie) -curl -s -X POST $BASE/api/v1/auth/setup-status # 200 +curl -s $BASE/api/v1/auth/setup-status # 200 curl -s $BASE/api/v1/auth/me -b cookies.txt # 200 (有 cookie) ``` diff --git a/backend/docs/AUTH_UPGRADE.md b/backend/docs/AUTH_UPGRADE.md index 344c488c4..75fe8b3cb 100644 --- a/backend/docs/AUTH_UPGRADE.md +++ b/backend/docs/AUTH_UPGRADE.md @@ -2,13 +2,16 @@ DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。 +完整设计见 [AUTH_DESIGN.md](AUTH_DESIGN.md)。 + ## 核心概念 认证模块采用**始终强制**策略: -- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志 +- 首次启动时不会自动创建账号;首次访问 `/setup` 时由操作者创建第一个 admin 账号 - 认证从一开始就是强制的,无竞争窗口 -- 历史对话(升级前创建的 thread)自动迁移到 admin 名下 +- 已有 admin 后,服务启动时会把历史对话(升级前创建且缺少 `user_id` 的 thread)迁移到 admin 名下 +- 新数据按用户隔离:thread、workspace/uploads/outputs、memory、自定义 agent 都归属当前用户 ## 升级步骤 @@ -25,39 +28,41 @@ cd backend && make install make dev ``` -控制台会输出: +如果没有 admin 账号,控制台只会提示: ``` ============================================================ - Admin account created on first boot - Email: admin@deerflow.dev - Password: aB3xK9mN_pQ7rT2w - Change it after login: Settings → Account + First boot detected — no admin account exists. + Visit /setup to complete admin account creation. ============================================================ ``` -如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。 +首次启动不会在日志里打印随机密码,也不会写入默认 admin。这样避免启动日志泄露凭据,也避免在操作者创建账号前出现可被猜测的默认身份。 -### 3. 登录 +### 3. 创建 admin -访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。 +访问 `http://localhost:2026/setup`,填写邮箱和密码创建第一个 admin 账号。创建成功后会自动登录并进入 workspace。 -### 4. 修改密码 +如果这是从无认证版本升级,创建 admin 后重启一次服务,让启动迁移把缺少 `user_id` 的历史 thread 归属到 admin。 -登录后进入 Settings → Account → Change Password。 +### 4. 登录 + +后续访问 `http://localhost:2026/login`,使用已创建的邮箱和密码登录。 ### 5. 添加用户(可选) -其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。 +其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话、上传文件、输出文件、memory 和自定义 agent。 ## 安全机制 | 机制 | 说明 | |------|------| | JWT HttpOnly Cookie | Token 不暴露给 JavaScript,防止 XSS 窃取 | -| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` | +| CSRF Double Submit Cookie | 受保护的 POST/PUT/PATCH/DELETE 请求需携带 `X-CSRF-Token`;登录/注册/初始化/登出走 auth 端点 Origin 校验 | | bcrypt 密码哈希 | 密码不以明文存储 | -| 多租户隔离 | 用户只能访问自己的 thread | +| Thread owner filter | `threads_meta.user_id` 由服务端认证上下文写入,搜索、读取、更新、删除默认按当前用户过滤 | +| 文件系统隔离 | 线程数据写入 `{base_dir}/users/{user_id}/threads/{thread_id}/user-data/`,sandbox 内统一映射为 `/mnt/user-data/` | +| Memory / agent 隔离 | 用户 memory 和自定义 agent 写入 `{base_dir}/users/{user_id}/...`;旧共享 agent 只作为只读兼容回退 | | HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 | ## 常见操作 @@ -74,22 +79,26 @@ python -m app.gateway.auth.reset_admin python -m app.gateway.auth.reset_admin --email user@example.com ``` -会输出新的随机密码。 +会把新的随机密码写入 `.deer-flow/admin_initial_credentials.txt`,文件权限为 `0600`。命令行只输出文件路径,不输出明文密码。 ### 完全重置 -删除用户数据库,重启后自动创建新 admin: +删除统一 SQLite 数据库,重启后重新访问 `/setup` 创建新 admin: ```bash -rm -f backend/.deer-flow/users.db -# 重启服务,控制台输出新密码 +rm -f backend/.deer-flow/data/deerflow.db +# 重启服务后访问 http://localhost:2026/setup ``` ## 数据存储 | 文件 | 内容 | |------|------| -| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) | +| `.deer-flow/data/deerflow.db` | 统一 SQLite 数据库(users、threads_meta、runs、feedback 等应用数据) | +| `.deer-flow/users/{user_id}/threads/{thread_id}/user-data/` | 用户线程的 workspace、uploads、outputs | +| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory | +| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory | +| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) | | `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) | ### 生产环境建议 @@ -111,19 +120,21 @@ python -c "import secrets; print(secrets.token_urlsafe(32))" | `/api/v1/auth/me` | GET | 获取当前用户信息 | | `/api/v1/auth/change-password` | POST | 修改密码 | | `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 | +| `/api/v1/auth/initialize` | POST | 首次初始化第一个 admin(仅无 admin 时可调用) | ## 兼容性 -- **标准模式**(`make dev`):完全兼容,admin 自动创建 +- **标准模式**(`make dev`):完全兼容;无 admin 时访问 `/setup` 初始化 - **Gateway 模式**(`make dev-pro`):完全兼容 -- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载 -- **IM 渠道**(Feishu/Slack/Telegram):通过 LangGraph SDK 通信,不经过认证层 +- **Docker 部署**:完全兼容,`.deer-flow/data/deerflow.db` 需持久化卷挂载 +- **IM 渠道**(Feishu/Slack/Telegram):通过 Gateway 内部认证通信,使用 `default` 用户桶 - **DeerFlowClient**(嵌入式):不经过 HTTP,不受认证影响 ## 故障排查 | 症状 | 原因 | 解决 | |------|------|------| -| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` | +| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` | +| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin | | 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 | | 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 | diff --git a/backend/docs/README.md b/backend/docs/README.md index da566005d..27e33f854 100644 --- a/backend/docs/README.md +++ b/backend/docs/README.md @@ -8,6 +8,7 @@ This directory contains detailed documentation for the DeerFlow backend. |----------|-------------| | [ARCHITECTURE.md](ARCHITECTURE.md) | System architecture overview | | [API.md](API.md) | Complete API reference | +| [AUTH_DESIGN.md](AUTH_DESIGN.md) | User authentication, CSRF, and per-user isolation design | | [CONFIGURATION.md](CONFIGURATION.md) | Configuration options | | [SETUP.md](SETUP.md) | Quick setup guide | @@ -42,6 +43,7 @@ docs/ ├── README.md # This file ├── ARCHITECTURE.md # System architecture ├── API.md # API reference +├── AUTH_DESIGN.md # User authentication and isolation design ├── CONFIGURATION.md # Configuration guide ├── SETUP.md # Setup instructions ├── FILE_UPLOAD.md # File upload feature From 506be8bffda8413ee0506f198ff47def931294db Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Tue, 12 May 2026 23:15:11 +0800 Subject: [PATCH 15/86] docs: clarify LangGraph compatibility entrypoints (#2914) --- backend/README.md | 4 ++++ backend/app/gateway/langgraph_auth.py | 12 ++++++++---- backend/docs/ARCHITECTURE.md | 3 ++- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/backend/README.md b/backend/README.md index 18d89c2be..8c61e2db2 100644 --- a/backend/README.md +++ b/backend/README.md @@ -242,6 +242,10 @@ backend/ └── Dockerfile # Container build ``` +`langgraph.json` is not the default service entrypoint. The scripts and Docker +deployments run the Gateway embedded runtime; the file is kept for LangGraph +tooling, Studio, or direct LangGraph Server compatibility. + --- ## Configuration diff --git a/backend/app/gateway/langgraph_auth.py b/backend/app/gateway/langgraph_auth.py index 38e020150..202fab2d5 100644 --- a/backend/app/gateway/langgraph_auth.py +++ b/backend/app/gateway/langgraph_auth.py @@ -1,8 +1,12 @@ -"""LangGraph Server auth handler — shares JWT logic with Gateway. +"""LangGraph compatibility auth handler — shares JWT logic with Gateway. -Loaded by LangGraph Server via langgraph.json ``auth.path``. -Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway, -so both modes validate tokens with the same secret and rules. +The default DeerFlow runtime is embedded in the FastAPI Gateway; scripts and +Docker deployments do not load this module. It is retained for LangGraph +tooling, Studio, or direct LangGraph Server compatibility through +``langgraph.json``'s ``auth.path``. + +When that compatibility path is used, this module reuses the same JWT and CSRF +rules as Gateway so both modes validate sessions consistently. Two layers: 1. @auth.authenticate — validates JWT cookie, extracts user_id, diff --git a/backend/docs/ARCHITECTURE.md b/backend/docs/ARCHITECTURE.md index f1557a6fb..47859cc9c 100644 --- a/backend/docs/ARCHITECTURE.md +++ b/backend/docs/ARCHITECTURE.md @@ -63,7 +63,8 @@ The agent runtime is embedded in the FastAPI Gateway and built on LangGraph for - Tool execution orchestration - SSE streaming for real-time responses -**Graph registry**: `langgraph.json` remains available for tooling and Studio compatibility. +**Graph registry**: `langgraph.json` remains available for tooling, Studio, or direct LangGraph Server compatibility. +It is not the default service entrypoint; scripts and Docker deployments run the Gateway embedded runtime. ```json { From 68d8caec1f6b543fa7936d8a0c382f33726e00b0 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Tue, 12 May 2026 23:18:54 +0800 Subject: [PATCH 16/86] fix(agents): make update_agent honor runtime.context user_id like setup_agent (#2867) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(agents): make update_agent honor runtime.context user_id like setup_agent PR #2784 hardened setup_agent to prefer runtime.context["user_id"] (set by inject_authenticated_user_context from the auth-validated request) over the contextvar, so an agent created during the bootstrap flow always lands under users//agents/. update_agent was left calling get_effective_user_id() unconditionally — the same class of bug that produced issues #2782 / #2862 still applies whenever the contextvar is not available on the executing task (background work, future cross-process drivers, checkpoint resume on a different task). In that regime update_agent silently routes writes to users/default/agents/, corrupting the shared default bucket and losing the user's edit. Extract the resolution policy into a shared resolve_runtime_user_id helper on deerflow.runtime.user_context and route both setup_agent and update_agent through it so the two halves of the lifecycle stay in lockstep. Add load-bearing end-to-end tests that drive a real langchain.agents create_agent graph with a fake LLM, exercising the full pipeline: HTTP wire format -> app.gateway.services.start_run config-assembly -> deerflow.runtime.runs.worker._build_runtime_context -> langchain.agents create_agent graph -> ToolNode dispatch (sync + async + sub-graph + ContextThreadPoolExecutor) -> setup_agent / update_agent The negative-control tests intentionally land in users/default/ to prove the positive tests are actually load-bearing rather than vacuously passing. The new test_update_agent_e2e_user_isolation suite included a test that failed against main and now passes after this fix. * style: ruff format on new e2e tests * test(e2e): real-server HTTP test driving setup_agent through the full ASGI stack Adds tests/test_setup_agent_http_e2e_real_server.py — a single load-bearing test that drives the entire FastAPI gateway through starlette.testclient. TestClient with no mocks above the LLM: - lifespan boots (config, sqlite engine, LangGraph runtime, channels) - POST /api/v1/auth/register (real password hash, real sqlite write, issues access_token + csrf_token cookies) - POST /api/threads (real thread_meta + checkpoint creation) - POST /api/threads/{id}/runs/stream with the exact wire shape the React frontend sends (assistant_id + input + config + context with agent_name/is_bootstrap) - AuthMiddleware -> CSRFMiddleware -> require_permission -> start_run -> inject_authenticated_user_context -> asyncio.create_task(run_agent) -> worker._build_runtime_context -> Runtime injection -> ToolNode dispatch -> real setup_agent - Asserts SOUL.md is under users//agents// and NOT under users/default/agents//. DEER_FLOW_HOME and the sqlite path are redirected into tmp_path so the test never touches the real .deer-flow directory or developer database. The only patch above the LLM boundary is replacing create_chat_model with a fake that emits a single setup_agent tool_call. This is the "真实验证" answer: it reproduces what curl-against-uvicorn would do, minus the network socket layer. * test: address Copilot review on user-isolation e2e tests - Drop "currently expected to FAIL" wording from update_agent e2e docstring and header (Copilot review): the fix is in this PR, the test pins the corrected behaviour rather than driving a future change. - Rephrase the assertion failure messages from "BUG:" to "REGRESSION:" to match the test's role on the fixed branch. - Bound _drain_stream with a wall-clock timeout, a max-bytes cap, and an early break on the "event: end" SSE frame (Copilot review). Stops the test from hanging on a stuck run or runaway heartbeat loop. - Replace the misleading "patch both module aliases" comment with an explanation of why patching lead_agent.agent.create_chat_model is the only correct target (Copilot review): lead_agent rebinds the symbol into its own namespace at import time, so patching deerflow.models is too late. * test(refactor): address WillemJiang review on user-isolation e2e tests - Extract the duplicated FakeToolCallingModel (and a build_single_tool_call_model helper) into tests/_agent_e2e_helpers.py. All three e2e files now import from the shared module instead of redefining the shim locally. - Convert the manual p.start() / p.stop() try/finally blocks in test_update_agent_e2e_user_isolation.py to contextlib.ExitStack so patch lifecycle is Pythonic and exception-safe. - Lift the isolated_app fixture's private-attribute resets into a named _reset_process_singletons helper with a comment block explaining why each singleton has to be invalidated for true e2e isolation, and why raising=False is intentional. Makes the fragility visible and the intent self-documenting rather than leaving the resets inline as opaque monkeypatch calls. Net change: -59 lines (143 -> 84) across the three test files, with every assertion intact. Full suite remains 69 passed / lint clean. * test(e2e): make real-server test self-supply its config CI's actions/checkout only ships config.example.yaml (the real config.yaml is gitignored), so the production config-discovery search (./config.yaml -> ../config.yaml -> $DEER_FLOW_CONFIG_PATH) finds nothing and the test fails at lifespan boot with FileNotFoundError. The dev-machine run passed only because a local config.yaml happened to exist. Write a minimal AppConfig-valid yaml into tmp_path and pin DEER_FLOW_CONFIG_PATH to it. The yaml carries just what the schema requires (a single fake-test-model entry, LocalSandboxProvider, sqlite database). The LLM never gets instantiated because the test patches create_chat_model on the lead agent module, so the api_key/base_url stay placeholders. Verified by hiding the local config.yaml to mirror the CI checkout — the test now passes in both environments. --- .../harness/deerflow/runtime/user_context.py | 28 ++ .../tools/builtins/setup_agent_tool.py | 11 +- .../tools/builtins/update_agent_tool.py | 12 +- backend/tests/_agent_e2e_helpers.py | 68 +++ .../test_setup_agent_e2e_user_isolation.py | 429 ++++++++++++++++++ .../test_setup_agent_http_e2e_real_server.py | 326 +++++++++++++ .../test_update_agent_e2e_user_isolation.py | 253 +++++++++++ 7 files changed, 1114 insertions(+), 13 deletions(-) create mode 100644 backend/tests/_agent_e2e_helpers.py create mode 100644 backend/tests/test_setup_agent_e2e_user_isolation.py create mode 100644 backend/tests/test_setup_agent_http_e2e_real_server.py create mode 100644 backend/tests/test_update_agent_e2e_user_isolation.py diff --git a/backend/packages/harness/deerflow/runtime/user_context.py b/backend/packages/harness/deerflow/runtime/user_context.py index ffe4be690..cfbb68c94 100644 --- a/backend/packages/harness/deerflow/runtime/user_context.py +++ b/backend/packages/harness/deerflow/runtime/user_context.py @@ -109,6 +109,34 @@ def get_effective_user_id() -> str: return str(user.id) +def resolve_runtime_user_id(runtime: object | None) -> str: + """Single source of truth for a tool/middleware's effective user_id. + + Resolution order (most authoritative first): + 1. ``runtime.context["user_id"]`` — set by ``inject_authenticated_user_context`` + in the gateway from the auth-validated ``request.state.user``. This is + the only source that survives boundaries where the contextvar may have + been lost (background tasks scheduled outside the request task, + worker pools that don't copy_context, future cross-process drivers). + 2. The ``_current_user`` ContextVar — set by the auth middleware at + request entry. Reliable for in-task work; copied by ``asyncio`` + child tasks and by ``ContextThreadPoolExecutor``. + 3. ``DEFAULT_USER_ID`` — last-resort fallback so unauthenticated + CLI / migration / test paths keep working without raising. + + Tools that persist user-scoped state (custom agents, memory, uploads) + MUST call this instead of ``get_effective_user_id()`` directly so they + benefit from the runtime.context channel that ``setup_agent`` already + relies on. + """ + context = getattr(runtime, "context", None) + if isinstance(context, dict): + ctx_user_id = context.get("user_id") + if ctx_user_id: + return str(ctx_user_id) + return get_effective_user_id() + + # --------------------------------------------------------------------------- # Sentinel-based user_id resolution # --------------------------------------------------------------------------- diff --git a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py index 2f796b005..dfbcf8b6e 100644 --- a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py @@ -7,19 +7,12 @@ from langgraph.types import Command from deerflow.config.agents_config import validate_agent_name from deerflow.config.paths import get_paths -from deerflow.runtime.user_context import get_effective_user_id +from deerflow.runtime.user_context import resolve_runtime_user_id from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) -def _get_runtime_user_id(runtime: Runtime) -> str: - context_user_id = runtime.context.get("user_id") if runtime.context else None - if context_user_id: - return str(context_user_id) - return get_effective_user_id() - - @tool(parse_docstring=True) def setup_agent( soul: str, @@ -45,7 +38,7 @@ def setup_agent( if agent_name: # Custom agents are persisted under the current user's bucket so # different users do not see each other's agents. - user_id = _get_runtime_user_id(runtime) + user_id = resolve_runtime_user_id(runtime) agent_dir = paths.user_agent_dir(user_id, agent_name) else: # Default agent (no agent_name): SOUL.md lives at the global base dir. diff --git a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py index b2dc8ca72..18500a248 100644 --- a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py @@ -27,7 +27,7 @@ from langgraph.types import Command from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.app_config import get_app_config from deerflow.config.paths import get_paths -from deerflow.runtime.user_context import get_effective_user_id +from deerflow.runtime.user_context import resolve_runtime_user_id from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) @@ -118,9 +118,13 @@ def update_agent( return _err("update_agent is only available inside a custom agent's chat. There is no agent_name in the current runtime context, so there is nothing to update. If you are inside the bootstrap flow, use setup_agent instead.") # Resolve the active user so that updates only affect this user's agent. - # ``get_effective_user_id`` returns DEFAULT_USER_ID when no auth context - # is set (matching how memory and thread storage behave). - user_id = get_effective_user_id() + # ``resolve_runtime_user_id`` prefers ``runtime.context["user_id"]`` (set by + # the gateway from the auth-validated request) and falls back to the + # contextvar, then DEFAULT_USER_ID. This matches setup_agent so a user + # creating an agent and later refining it always touches the same files, + # even if the contextvar gets lost across an async/thread boundary + # (issue #2782 / #2862 class of bugs). + user_id = resolve_runtime_user_id(runtime) # Reject an unknown ``model`` *before* touching the filesystem. Otherwise # ``_resolve_model_name`` silently falls back to the default at runtime diff --git a/backend/tests/_agent_e2e_helpers.py b/backend/tests/_agent_e2e_helpers.py new file mode 100644 index 000000000..2f28390a9 --- /dev/null +++ b/backend/tests/_agent_e2e_helpers.py @@ -0,0 +1,68 @@ +"""Shared helpers for user-isolation e2e tests on the custom-agent tooling. + +Centralises the small fake-LLM shim and a few test-data builders that the +three e2e files in this PR (``test_setup_agent_e2e_user_isolation``, +``test_update_agent_e2e_user_isolation``, ``test_setup_agent_http_e2e_real_server``) +all need. The shim is what lets a real ``langchain.agents.create_agent`` +graph run without an API key — every other layer in those tests is real +production code, which is the entire point of the test design. +""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage +from langchain_core.runnables import Runnable + + +class FakeToolCallingModel(FakeMessagesListChatModel): + """FakeMessagesListChatModel plus a no-op ``bind_tools`` for create_agent. + + ``langchain.agents.create_agent`` calls ``model.bind_tools(...)`` to + expose the tool schemas to the model; the upstream fake raises + ``NotImplementedError`` there. We just return ``self`` because we + drive deterministic tool_call output via ``responses=...``, no schema + handling needed. + """ + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + return self + + +def build_single_tool_call_model( + *, + tool_name: str, + tool_args: dict[str, Any], + tool_call_id: str = "call_e2e_1", + final_text: str = "done", +) -> FakeToolCallingModel: + """Build a fake model that emits exactly one tool_call then finishes. + + Two-turn behaviour, identical across our e2e tests: + turn 1 → AIMessage with a single tool_call for *tool_name* + turn 2 → AIMessage with *final_text* (terminates the agent loop) + """ + return FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": tool_name, + "args": tool_args, + "id": tool_call_id, + "type": "tool_call", + } + ], + ), + AIMessage(content=final_text), + ] + ) diff --git a/backend/tests/test_setup_agent_e2e_user_isolation.py b/backend/tests/test_setup_agent_e2e_user_isolation.py new file mode 100644 index 000000000..034d4da84 --- /dev/null +++ b/backend/tests/test_setup_agent_e2e_user_isolation.py @@ -0,0 +1,429 @@ +"""End-to-end verification for issue #2862 (and the regression of #2782). + +Goal: prove — without trusting any single layer's claim — that an authenticated +user creating a custom agent through the real ``setup_agent`` tool, driven by a +real LangGraph ``create_agent`` graph, ends up with files under +``users//agents/`` and **not** under ``users/default/agents/...``. + +We intentionally exercise the full pipeline: + + HTTP body shape (mimics LangGraph SDK wire format) + -> app.gateway.services.start_run config-assembly chain + -> deerflow.runtime.runs.worker._build_runtime_context + -> langchain.agents.create_agent graph + -> ToolNode dispatch + -> setup_agent tool + +The only thing we mock is the LLM (FakeMessagesListChatModel) — every layer +that handles ``user_id`` is the real production code path. If the +``user_id`` propagation is broken anywhere in this chain, these tests will +fail. + +These tests intentionally ``no_auto_user`` so that the ``contextvar`` +fallback would put files into ``default/`` if propagation breaks. +""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch +from uuid import UUID + +import pytest +from _agent_e2e_helpers import FakeToolCallingModel +from langchain_core.messages import AIMessage, HumanMessage + +from app.gateway.services import ( + build_run_config, + inject_authenticated_user_context, + merge_run_context_overrides, +) +from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context + +# --------------------------------------------------------------------------- +# Helpers — real production code paths +# --------------------------------------------------------------------------- + + +def _make_request(user_id_str: str | None) -> SimpleNamespace: + """Build a fake FastAPI Request that carries an authenticated user.""" + if user_id_str is None: + user = None + else: + # User.id is UUID in production; honour that + user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") + return SimpleNamespace(state=SimpleNamespace(user=user)) + + +def _assemble_config( + *, + body_config: dict | None, + body_context: dict | None, + request_user_id: str | None, + thread_id: str = "thread-e2e", + assistant_id: str = "lead_agent", +) -> dict: + """Replay the **exact** start_run config-assembly sequence.""" + config = build_run_config(thread_id, body_config, None, assistant_id=assistant_id) + merge_run_context_overrides(config, body_context) + inject_authenticated_user_context(config, _make_request(request_user_id)) + return config + + +def _make_paths_mock(tmp_path: Path): + """Mirror the production paths.user_agent_dir signature.""" + from unittest.mock import MagicMock + + paths = MagicMock() + paths.base_dir = tmp_path + paths.agent_dir = lambda name: tmp_path / "agents" / name + paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name + return paths + + +# --------------------------------------------------------------------------- +# L1-L3: HTTP wire format → start_run → worker._build_runtime_context +# --------------------------------------------------------------------------- + + +class TestConfigAssembly: + """Covers L1-L3: validate that user_id reaches runtime_ctx for every wire shape.""" + + def test_typical_wire_format_user_id_in_runtime_ctx(self): + """Real frontend: body.config={recursion_limit}, body.context={agent_name,...}.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent", "is_bootstrap": True, "mode": "flash"}, + request_user_id="11111111-2222-3333-4444-555555555555", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555" + assert runtime_ctx["agent_name"] == "myagent" + + def test_body_context_none_still_injects_user_id(self): + """If frontend omits body.context entirely, inject must still create it.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context=None, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_body_context_empty_dict_still_injects_user_id(self): + """body.context={} (falsy) path: inject must still produce user_id.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={}, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_body_config_already_contains_context_field(self): + """body.config={'context': {...}} (LG 0.6 alt wire): inject still wins.""" + config = _assemble_config( + body_config={"context": {"agent_name": "myagent"}, "recursion_limit": 1000}, + body_context=None, + request_user_id="aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + def test_client_supplied_user_id_is_overridden(self): + """Spoofed client user_id must be overwritten by inject (auth-trusted source).""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent", "user_id": "spoofed"}, + request_user_id="11111111-2222-3333-4444-555555555555", + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert runtime_ctx["user_id"] == "11111111-2222-3333-4444-555555555555" + + def test_unauthenticated_request_does_not_inject(self): + """If request.state.user is missing (impossible under fail-closed auth, but + verify defensively), inject must not write user_id and runtime_ctx must + therefore lack it — forcing the tool fallback path to reveal itself.""" + config = _assemble_config( + body_config={"recursion_limit": 1000}, + body_context={"agent_name": "myagent"}, + request_user_id=None, + ) + runtime_ctx = _build_runtime_context("thread-e2e", "run-1", config.get("context"), None) + assert "user_id" not in runtime_ctx + + +# --------------------------------------------------------------------------- +# L4-L7: Real LangGraph create_agent driving the real setup_agent tool +# --------------------------------------------------------------------------- + + +def _build_real_bootstrap_graph(authenticated_user_id: str): + """Construct a real LangGraph using create_agent + the real setup_agent tool. + + The LLM is faked (FakeMessagesListChatModel) so we don't need an API key. + Everything else — ToolNode dispatch, runtime injection, middleware — is + the real production code path. + """ + from langchain.agents import create_agent + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + # First model turn: emit a tool_call for setup_agent + # Second model turn (after tool result): final answer (terminates the loop) + fake_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": { + "soul": "# My E2E Agent\n\nA SOUL written by the model.", + "description": "End-to-end test agent", + }, + "id": "call_setup_1", + "type": "tool_call", + } + ], + ), + AIMessage(content=f"Done. Agent created for user {authenticated_user_id}."), + ] + ) + + graph = create_agent( + model=fake_model, + tools=[setup_agent], + system_prompt="You are a bootstrap agent. Call setup_agent immediately.", + ) + return graph + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_real_graph_real_setup_agent_writes_to_authenticated_user_dir(tmp_path: Path): + """The smoking-gun test for issue #2862. + + Under no_auto_user (contextvar = empty), if user_id propagation through + runtime.context is broken, setup_agent will fall back to DEFAULT_USER_ID + and write to users/default/agents/... The assertion that this directory + DOES NOT exist is what makes this test load-bearing. + """ + from langgraph.runtime import Runtime + + auth_uid = "abcdef01-2345-6789-abcd-ef0123456789" + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "e2e-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-1", + ) + + # Replay worker.run_agent's runtime construction. This is the key step: + # it is what makes ToolRuntime.context contain user_id when the tool + # actually fires. + runtime_ctx = _build_runtime_context("thread-e2e-1", "run-1", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_real_bootstrap_graph(auth_uid) + + # Patch get_paths only (the file-system rooting); everything else is real + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Drive the real graph. This goes through real ToolNode + real Runtime merge. + final_state = await graph.ainvoke( + {"messages": [HumanMessage(content="Create an agent named e2e-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "e2e-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "e2e-agent" + + # Load-bearing assertions: + assert expected_dir.exists(), f"Agent directory not found at the authenticated user's path. Expected: {expected_dir}. tmp_path tree: {[str(p) for p in tmp_path.rglob('*')]}" + assert (expected_dir / "SOUL.md").read_text() == "# My E2E Agent\n\nA SOUL written by the model." + assert (expected_dir / "config.yaml").exists() + assert not default_dir.exists(), "REGRESSION: agent landed under users/default/. user_id propagation broke somewhere between HTTP layer and ToolRuntime.context." + + # And final state should reflect tool success + last = final_state["messages"][-1] + assert "Done" in (last.content if isinstance(last.content, str) else str(last.content)) + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_inject_failure_falls_back_to_default_proving_test_is_load_bearing(tmp_path: Path): + """Negative control: if inject does NOT happen (no user in request), and + contextvar is empty (no_auto_user), setup_agent must land in default/. + + This proves the positive test is actually load-bearing — i.e. it would + have failed before PR #2784, not passed accidentally. + """ + from langgraph.runtime import Runtime + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "fallback-agent", "is_bootstrap": True}, + request_user_id=None, # no auth — inject is a no-op + thread_id="thread-e2e-2", + ) + + runtime_ctx = _build_runtime_context("thread-e2e-2", "run-2", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_real_bootstrap_graph("does-not-matter") + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + await graph.ainvoke( + {"messages": [HumanMessage(content="Create fallback-agent")]}, + config=config, + ) + + default_dir = tmp_path / "users" / "default" / "agents" / "fallback-agent" + assert default_dir.exists(), "Negative control failed: even without inject + contextvar, agent did not land in default/. The test infrastructure may not be reproducing the bug condition." + + +# --------------------------------------------------------------------------- +# L5: Sub-graph runtime propagation (the task tool case) +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +@pytest.mark.asyncio +async def test_subgraph_invocation_preserves_user_id_in_runtime(tmp_path: Path): + """When a parent graph invokes a child graph (the pattern used by + subagents), parent_runtime.merge() must keep user_id intact. + + We construct a child graph that contains setup_agent and call it from + a parent graph's tool. If LangGraph re-creates the Runtime and drops + user_id at the sub-graph boundary, this fails. + """ + from langchain.agents import create_agent + from langgraph.runtime import Runtime + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + auth_uid = "deadbeef-0000-1111-2222-333344445555" + + # Inner graph: same as the bootstrap flow + inner_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": {"soul": "# Inner", "description": "subgraph"}, + "id": "call_inner_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="inner done"), + ] + ) + inner_graph = create_agent( + model=inner_model, + tools=[setup_agent], + system_prompt="inner", + ) + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "subgraph-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-3", + ) + runtime_ctx = _build_runtime_context("thread-e2e-3", "run-3", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Direct sub-graph invoke (mimics what a subagent invocation looks like + # — distinct ainvoke call, but parent config carries the same runtime). + await inner_graph.ainvoke( + {"messages": [HumanMessage(content="Create subgraph-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "subgraph-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "subgraph-agent" + assert expected_dir.exists() + assert not default_dir.exists() + + +# --------------------------------------------------------------------------- +# L6: Sync tool path through ContextThreadPoolExecutor +# --------------------------------------------------------------------------- + + +def test_sync_tool_dispatch_through_thread_pool_uses_runtime_context(tmp_path: Path): + """setup_agent is a sync function. When dispatched through ToolNode's + ContextThreadPoolExecutor, runtime.context must still carry user_id — + not via thread-local copy_context (which only carries contextvars), but + because it was passed in as the ToolRuntime constructor argument. + """ + from langchain.agents import create_agent + from langgraph.runtime import Runtime + + from deerflow.tools.builtins.setup_agent_tool import setup_agent + + auth_uid = "11112222-3333-4444-5555-666677778888" + + fake_model = FakeToolCallingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "setup_agent", + "args": {"soul": "# Sync", "description": "sync path"}, + "id": "call_sync_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="sync done"), + ] + ) + graph = create_agent(model=fake_model, tools=[setup_agent], system_prompt="sync") + + config = _assemble_config( + body_config={"recursion_limit": 50}, + body_context={"agent_name": "sync-agent", "is_bootstrap": True}, + request_user_id=auth_uid, + thread_id="thread-e2e-4", + ) + runtime_ctx = _build_runtime_context("thread-e2e-4", "run-4", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + with patch( + "deerflow.tools.builtins.setup_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ): + # Use SYNC invoke to hit the ContextThreadPoolExecutor path + graph.invoke( + {"messages": [HumanMessage(content="Create sync-agent")]}, + config=config, + ) + + expected_dir = tmp_path / "users" / auth_uid / "agents" / "sync-agent" + default_dir = tmp_path / "users" / "default" / "agents" / "sync-agent" + assert expected_dir.exists() + assert not default_dir.exists() diff --git a/backend/tests/test_setup_agent_http_e2e_real_server.py b/backend/tests/test_setup_agent_http_e2e_real_server.py new file mode 100644 index 000000000..950d040a0 --- /dev/null +++ b/backend/tests/test_setup_agent_http_e2e_real_server.py @@ -0,0 +1,326 @@ +"""Real HTTP end-to-end verification for issue #2862's setup_agent path. + +This test drives the **entire** FastAPI gateway through ``starlette.testclient.TestClient``: + + starlette.testclient.TestClient (real ASGI stack) + -> AuthMiddleware (real cookie parsing, real JWT decode) + -> /api/v1/auth/register endpoint (real password hash + sqlite write) + -> /api/threads/{id}/runs/stream endpoint (real start_run config-assembly) + -> background asyncio.create_task(run_agent) (real worker, real Runtime) + -> langchain.agents.create_agent graph (real, with fake LLM) + -> ToolNode dispatch (real) + -> setup_agent tool (real file I/O) + +The only mock is the LLM (no API key needed). Every layer that participates +in ``user_id`` propagation — auth, ContextVar, ``inject_authenticated_user_context``, +``worker._build_runtime_context``, ``Runtime.merge`` — is the real production +code path. If the chain is broken at any layer, this test fails. + +This is what "真实验证" looks like for a server that lives behind authentication: +register a user, log in (cookie), POST to /runs/stream, wait for the run to +finish, then read the filesystem. +""" + +from __future__ import annotations + +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest +from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model + + +def _build_fake_create_chat_model(agent_name: str): + """Return a callable matching the real ``create_chat_model`` signature. + + Whenever the lead agent constructs a chat model during the bootstrap flow, + we hand it a fake that emits a single setup_agent tool_call on its first + turn, then a benign final answer on its second turn. + """ + + def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel: + return build_single_tool_call_model( + tool_name="setup_agent", + tool_args={ + "soul": f"# Real HTTP E2E SOUL for {agent_name}", + "description": "real-http-e2e agent", + }, + tool_call_id="call_real_http_1", + final_text=f"Agent {agent_name} created via real HTTP e2e.", + ) + + return fake_create_chat_model + + +@pytest.fixture +def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch): + """Stand up an isolated DeerFlow data root + config under tmp_path. + + - Sets ``DEER_FLOW_HOME`` so paths land under tmp_path, not the real + ``.deer-flow`` directory. + - Stages a copy of the project's ``config.yaml`` (or ``config.example.yaml`` + on a fresh CI checkout where ``config.yaml`` is gitignored) and pins + ``DEER_FLOW_CONFIG_PATH`` to it, so lifespan boot doesn't depend on the + developer's local config layout. + - Sets a placeholder OPENAI_API_KEY because the config has + ``$OPENAI_API_KEY`` that gets resolved at parse time; the LLM itself is + mocked, so any non-empty value works. + """ + home = tmp_path / "deer-flow-home" + home.mkdir() + monkeypatch.setenv("DEER_FLOW_HOME", str(home)) + monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-key-not-used-because-llm-is-mocked") + monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid") + + # Hermetic config: do not depend on whether the dev machine has a real + # ``config.yaml`` at the repo root. CI's ``actions/checkout`` only ships + # ``config.example.yaml`` (and its ``models:`` list is commented out, so + # AppConfig validation would reject it). Write a minimal, self-sufficient + # config to tmp_path and pin ``DEER_FLOW_CONFIG_PATH`` to it. + staged_config = tmp_path / "config.yaml" + staged_config.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8") + monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(staged_config)) + + return home + + +# Minimal config that satisfies AppConfig + LeadAgent's _resolve_model_name. +# The model `use` path must resolve to a real class for config parsing to +# succeed; the test patches ``create_chat_model`` on the lead agent module, +# so the model is never actually instantiated. SandboxConfig.use is required +# at schema level; LocalSandboxProvider is the only sandbox that runs without +# Docker. +_MINIMAL_CONFIG_YAML = """\ +log_level: info +models: + - name: fake-test-model + display_name: Fake Test Model + use: langchain_openai:ChatOpenAI + model: gpt-4o-mini + api_key: $OPENAI_API_KEY + base_url: $OPENAI_API_BASE +sandbox: + use: deerflow.sandbox.local:LocalSandboxProvider +agents_api: + enabled: true +database: + backend: sqlite +""" + + +def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None: + """Reset every process-wide cache that would survive across tests. + + This fixture stands up a full FastAPI app + sqlite DB + LangGraph runtime + inside ``tmp_path``. To get true per-test isolation we have to invalidate + a handful of module-level caches that production normally never resets, + so they pick up our test-only ``DEER_FLOW_HOME`` and sqlite path: + + - ``deerflow.config.app_config`` caches the parsed ``config.yaml``. + - ``deerflow.config.paths`` caches the ``Paths`` singleton derived from + ``DEER_FLOW_HOME`` at first access. + - ``deerflow.persistence.engine`` caches the SQLAlchemy engine and + session factory after the first call to ``init_engine_from_config``. + + ``raising=False`` keeps the fixture resilient if upstream renames or + drops one of these attributes — the test will simply skip that reset + instead of failing with a confusing AttributeError, and the next test + to call ``get_app_config()``/``get_paths()`` will surface the real + incompatibility loudly. + """ + from deerflow.config import app_config as app_config_module + from deerflow.config import paths as paths_module + from deerflow.persistence import engine as engine_module + + for module, attr in ( + (app_config_module, "_app_config"), + (app_config_module, "_app_config_path"), + (app_config_module, "_app_config_mtime"), + (paths_module, "_paths_singleton"), + (engine_module, "_engine"), + (engine_module, "_session_factory"), + ): + monkeypatch.setattr(module, attr, None, raising=False) + + +@pytest.fixture +def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch): + """Build a fresh FastAPI app inside a clean DEER_FLOW_HOME. + + Each test gets its own sqlite DB and checkpoint store under ``tmp_path``, + with no cross-test contamination. + """ + _reset_process_singletons(monkeypatch) + + # Re-resolve the config from the test-only DEER_FLOW_HOME and pin its + # sqlite path into tmp_path so the lifespan-time engine init lands there. + from deerflow.config import app_config as app_config_module + + cfg = app_config_module.get_app_config() + cfg.database.sqlite_dir = str(isolated_deer_flow_home / "db") + + from app.gateway.app import create_app + + return create_app() + + +def _drain_stream(response, *, timeout: float = 30.0, max_bytes: int = 4 * 1024 * 1024) -> str: + """Consume an SSE response body until the run terminates and return the text. + + Bounded to keep the test fail-fast: + - Stops as soon as an ``event: end`` SSE frame is observed (the gateway + sends this when the background run finishes — see ``services.format_sse`` + and ``StreamBridge.publish_end``). + - Stops at ``timeout`` seconds wall-clock so a stuck run / runaway heartbeat + loop surfaces a real failure instead of hanging pytest. + - Stops at ``max_bytes`` so a runaway producer can't OOM the test process. + """ + import time as _time + + deadline = _time.monotonic() + timeout + body = b"" + for chunk in response.iter_bytes(): + body += chunk + if b"event: end" in body: + break + if len(body) >= max_bytes: + break + if _time.monotonic() >= deadline: + break + return body.decode("utf-8", errors="replace") + + +def _wait_for_file(path: Path, *, timeout: float = 10.0) -> bool: + """Block until *path* exists or *timeout* elapses. + + The run completes inside ``asyncio.create_task`` after start_run returns, + so the test must wait for the background task to flush its writes. + """ + import time as _time + + deadline = _time.monotonic() + timeout + while _time.monotonic() < deadline: + if path.exists(): + return True + _time.sleep(0.05) + return False + + +@pytest.mark.no_auto_user +def test_real_http_create_agent_lands_in_authenticated_user_dir( + isolated_app: Any, + isolated_deer_flow_home: Path, + monkeypatch: pytest.MonkeyPatch, +): + """The full real-server contract test. + + 1. Register a real user via POST /api/v1/auth/register (also auto-logs in) + 2. POST to /api/threads/{tid}/runs/stream with the **exact** body shape the + frontend (LangGraph SDK) sends during the bootstrap flow. + 3. Wait for the background run to finish. + 4. Assert SOUL.md exists under users//agents//. + 5. Assert NOTHING exists under users/default/agents//. + """ + # ``deerflow.agents.lead_agent.agent`` imports ``create_chat_model`` with + # ``from deerflow.models import create_chat_model`` at module load time, + # rebinding the symbol into its own namespace. So the only patch that + # intercepts the call is the bound name on ``lead_agent.agent`` — patching + # ``deerflow.models.create_chat_model`` would be too late. + agent_name = "real-http-agent" + + from starlette.testclient import TestClient + + with ( + patch( + "deerflow.agents.lead_agent.agent.create_chat_model", + new=_build_fake_create_chat_model(agent_name), + ), + TestClient(isolated_app) as client, + ): + # --- 1. Register & auto-login --- + register = client.post( + "/api/v1/auth/register", + json={"email": "e2e-user@example.com", "password": "very-strong-password-123"}, + ) + assert register.status_code == 201, register.text + registered = register.json() + auth_uid = registered["id"] + # The endpoint sets both access_token (auth) and csrf_token (CSRF Double + # Submit Cookie) cookies; the TestClient cookie jar propagates them. + assert client.cookies.get("access_token"), "register endpoint must set session cookie" + csrf_token = client.cookies.get("csrf_token") + assert csrf_token, "register endpoint must set csrf_token cookie" + + # --- 2. Create a thread (require_existing=True on /runs/stream means + # we must call POST /api/threads first; the React frontend does the + # same via the LangGraph SDK's threads.create) --- + import uuid as _uuid + + thread_id = str(_uuid.uuid4()) + created = client.post( + "/api/threads", + json={"thread_id": thread_id, "metadata": {}}, + headers={"X-CSRF-Token": csrf_token}, + ) + assert created.status_code == 200, created.text + + # --- 3. POST /runs/stream with the bootstrap wire format --- + # This is the EXACT shape the React frontend sends after PR #2784: + # thread.submit(input, {config, context}) -> + # POST /api/threads/{id}/runs/stream body = + # {assistant_id, input, config, context} + body = { + "assistant_id": "lead_agent", + "input": { + "messages": [ + { + "role": "user", + "content": (f"The new custom agent name is {agent_name}. Help me design its SOUL.md before saving it."), + } + ] + }, + "config": {"recursion_limit": 50}, + "context": { + "agent_name": agent_name, + "is_bootstrap": True, + "mode": "flash", + "thinking_enabled": False, + "is_plan_mode": False, + "subagent_enabled": False, + }, + "stream_mode": ["values"], + } + # The /stream endpoint returns SSE; we drain it so the server-side + # background task (run_agent) gets to completion before we look at disk. + with client.stream( + "POST", + f"/api/threads/{thread_id}/runs/stream", + json=body, + headers={"X-CSRF-Token": csrf_token}, + ) as resp: + assert resp.status_code == 200, resp.read().decode() + transcript = _drain_stream(resp) + + # Sanity: the stream should have produced at least one event + assert "event:" in transcript, f"no SSE events in response: {transcript[:500]!r}" + + # --- 4. Verify filesystem outcome --- + expected_dir = isolated_deer_flow_home / "users" / auth_uid / "agents" / agent_name + default_dir = isolated_deer_flow_home / "users" / "default" / "agents" / agent_name + + # The setup_agent tool runs inside the background asyncio task spawned + # by start_run; SSE-drain typically waits for it, but we add a bounded + # poll to be robust against scheduler jitter. + assert _wait_for_file(expected_dir / "SOUL.md", timeout=15.0), ( + "SOUL.md did not appear under users//agents/. " + f"Expected: {expected_dir / 'SOUL.md'}. " + f"tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}. " + f"SSE transcript tail: {transcript[-1000:]!r}" + ) + + soul_text = (expected_dir / "SOUL.md").read_text() + assert agent_name in soul_text, f"unexpected SOUL content: {soul_text!r}" + + # The smoking-gun assertion: the agent must NOT have landed in default/ + assert not default_dir.exists(), f"REGRESSION: agent landed under users/default/{agent_name} instead of the authenticated user. Default-dir contents: {list(default_dir.rglob('*')) if default_dir.exists() else 'n/a'}" diff --git a/backend/tests/test_update_agent_e2e_user_isolation.py b/backend/tests/test_update_agent_e2e_user_isolation.py new file mode 100644 index 000000000..7fa725352 --- /dev/null +++ b/backend/tests/test_update_agent_e2e_user_isolation.py @@ -0,0 +1,253 @@ +"""End-to-end verification for update_agent's user_id resolution. + +PR #2784 hardened setup_agent to prefer runtime.context["user_id"] over the +contextvar. update_agent had the same latent gap: it unconditionally called +get_effective_user_id() at module level, so any scenario where the contextvar +was unavailable while runtime.context carried user_id (a background task +scheduled outside the request task, a worker pool that doesn't copy_context, +checkpoint resume on a different task) would silently route writes to +users/default/agents/... + +These tests are load-bearing under @no_auto_user (contextvar empty): + +- The negative-control test confirms the fixture actually puts the tool in + the regime where the contextvar fallback would land in users/default/. + Without that, the positive test would be vacuously satisfied. +- The positive test verifies update_agent honours runtime.context["user_id"] + injected by inject_authenticated_user_context in the gateway. Before the + fix in this PR, this test failed; now it passes. +""" + +from __future__ import annotations + +from contextlib import ExitStack +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import UUID + +import pytest +import yaml +from _agent_e2e_helpers import build_single_tool_call_model +from langchain_core.messages import HumanMessage + +from app.gateway.services import ( + build_run_config, + inject_authenticated_user_context, + merge_run_context_overrides, +) +from deerflow.runtime.runs.worker import _build_runtime_context, _install_runtime_context + + +def _make_request(user_id_str: str | None) -> SimpleNamespace: + user = SimpleNamespace(id=UUID(user_id_str), email="alice@local") if user_id_str else None + return SimpleNamespace(state=SimpleNamespace(user=user)) + + +def _assemble_config(*, body_context: dict | None, request_user_id: str | None, thread_id: str) -> dict: + config = build_run_config(thread_id, {"recursion_limit": 50}, None, assistant_id="lead_agent") + merge_run_context_overrides(config, body_context) + inject_authenticated_user_context(config, _make_request(request_user_id)) + return config + + +def _seed_existing_agent(tmp_path: Path, user_id: str, agent_name: str, soul: str = "# Original"): + """Pre-create an agent on disk for update_agent to overwrite.""" + agent_dir = tmp_path / "users" / user_id / "agents" / agent_name + agent_dir.mkdir(parents=True, exist_ok=True) + (agent_dir / "config.yaml").write_text( + yaml.dump({"name": agent_name, "description": "old"}, allow_unicode=True), + encoding="utf-8", + ) + (agent_dir / "SOUL.md").write_text(soul, encoding="utf-8") + return agent_dir + + +def _make_paths_mock(tmp_path: Path): + paths = MagicMock() + paths.base_dir = tmp_path + paths.agent_dir = lambda name: tmp_path / "agents" / name + paths.user_agent_dir = lambda user_id, name: tmp_path / "users" / user_id / "agents" / name + return paths + + +def _patch_update_agent_dependencies(tmp_path: Path): + """update_agent reads load_agent_config + get_app_config — stub them + minimally so the tool can run without a real config file or LLM.""" + fake_model_cfg = SimpleNamespace(name="fake-model") + fake_app_cfg = MagicMock() + fake_app_cfg.get_model_config = lambda name: fake_model_cfg if name == "fake-model" else None + + return [ + patch( + "deerflow.tools.builtins.update_agent_tool.get_paths", + return_value=_make_paths_mock(tmp_path), + ), + patch( + "deerflow.tools.builtins.update_agent_tool.get_app_config", + return_value=fake_app_cfg, + ), + # load_agent_config (used by update_agent to read existing config) also + # reads paths via its own module-level get_paths reference. Patch it too + # or the tool returns "Agent does not exist" before touching disk. + patch( + "deerflow.config.agents_config.get_paths", + return_value=_make_paths_mock(tmp_path), + ), + ] + + +def _build_update_graph(*, soul_payload: str): + from langchain.agents import create_agent + + from deerflow.tools.builtins.update_agent_tool import update_agent + + fake_model = build_single_tool_call_model( + tool_name="update_agent", + tool_args={"soul": soul_payload, "description": "refined"}, + tool_call_id="call_update_1", + final_text="updated", + ) + return create_agent(model=fake_model, tools=[update_agent], system_prompt="updater") + + +# --------------------------------------------------------------------------- +# Negative control — proves the test environment puts update_agent in the +# regime where the contextvar fallback would land in default/. +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +def test_update_agent_falls_back_to_default_when_no_inject_and_no_contextvar(tmp_path: Path): + """No request.state.user, no contextvar — update_agent must look in + users/default/agents/. We seed the file there so the tool succeeds and + we know which directory it actually consulted.""" + from langgraph.runtime import Runtime + + _seed_existing_agent(tmp_path, "default", "fallback-target") + + config = _assemble_config( + body_context={"agent_name": "fallback-target"}, + request_user_id=None, # no auth, inject is no-op + thread_id="thread-update-1", + ) + runtime_ctx = _build_runtime_context("thread-update-1", "run-1", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# Fallback Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + graph.invoke( + {"messages": [HumanMessage(content="update fallback-target")]}, + config=config, + ) + + soul = (tmp_path / "users" / "default" / "agents" / "fallback-target" / "SOUL.md").read_text() + assert soul == "# Fallback Updated", "Sanity: tool should have written under default/" + + +# --------------------------------------------------------------------------- +# Regression guard — passes on this branch, would fail on main before the fix. +# --------------------------------------------------------------------------- + + +@pytest.mark.no_auto_user +def test_update_agent_should_use_runtime_context_user_id_when_contextvar_missing(tmp_path: Path): + """update_agent prefers the authenticated user_id carried in + runtime.context (placed there by inject_authenticated_user_context) + over the contextvar — same contract as setup_agent (PR #2784). + + Before this PR's fix, update_agent unconditionally called + get_effective_user_id() and landed in default/ whenever the contextvar + was unavailable. This test pins the corrected behaviour. + """ + from langgraph.runtime import Runtime + + auth_uid = "abcdef01-2345-6789-abcd-ef0123456789" + + # Seed the agent in BOTH locations so we can prove which one was opened. + auth_dir = _seed_existing_agent(tmp_path, auth_uid, "shared-name", soul="# Auth Original") + default_dir = _seed_existing_agent(tmp_path, "default", "shared-name", soul="# Default Original") + + config = _assemble_config( + body_context={"agent_name": "shared-name"}, + request_user_id=auth_uid, + thread_id="thread-update-2", + ) + runtime_ctx = _build_runtime_context("thread-update-2", "run-2", config.get("context"), None) + assert runtime_ctx["user_id"] == auth_uid, "Pre-condition: inject must have placed user_id into runtime_ctx" + + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# Auth Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + graph.invoke( + {"messages": [HumanMessage(content="update shared-name")]}, + config=config, + ) + + auth_soul = (auth_dir / "SOUL.md").read_text() + default_soul = (default_dir / "SOUL.md").read_text() + + assert auth_soul == "# Auth Updated", f"REGRESSION: update_agent ignored runtime.context['user_id']={auth_uid!r} and routed the write to users/default/ instead. auth_soul={auth_soul!r}, default_soul={default_soul!r}" + assert default_soul == "# Default Original", "REGRESSION: update_agent corrupted the shared default-user agent. It should have written under the authenticated user's path." + + +# --------------------------------------------------------------------------- +# Positive — when contextvar IS the auth user (the normal HTTP case), things +# already work. Pin it as a regression guard so future refactors don't +# accidentally break the contextvar path in pursuit of the runtime-context fix. +# --------------------------------------------------------------------------- + + +def test_update_agent_uses_contextvar_when_present(tmp_path: Path, monkeypatch): + """The normal HTTP case: contextvar is set by auth_middleware. This must + keep working regardless of how runtime.context is populated.""" + from types import SimpleNamespace as _SN + + from deerflow.runtime.user_context import reset_current_user, set_current_user + + auth_uid = "11112222-3333-4444-5555-666677778888" + user = _SN(id=auth_uid, email="ctxvar@local") + + _seed_existing_agent(tmp_path, auth_uid, "ctxvar-agent", soul="# Original") + + from langgraph.runtime import Runtime + + config = _assemble_config( + body_context={"agent_name": "ctxvar-agent"}, + request_user_id=auth_uid, + thread_id="thread-update-3", + ) + runtime_ctx = _build_runtime_context("thread-update-3", "run-3", config.get("context"), None) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=runtime_ctx, store=None) + config.setdefault("configurable", {})["__pregel_runtime"] = runtime + + graph = _build_update_graph(soul_payload="# CtxVar Updated") + + with ExitStack() as stack: + for p in _patch_update_agent_dependencies(tmp_path): + stack.enter_context(p) + token = set_current_user(user) + try: + final = graph.invoke( + {"messages": [HumanMessage(content="update ctxvar-agent")]}, + config=config, + ) + finally: + reset_current_user(token) + + # surface the tool's reply for debug if it errored + tool_replies = [m.content for m in final["messages"] if getattr(m, "type", "") == "tool"] + soul = (tmp_path / "users" / auth_uid / "agents" / "ctxvar-agent" / "SOUL.md").read_text() + assert soul == "# CtxVar Updated", f"tool replies: {tool_replies}" From e9deb6c2f203d633b88578e6400c7fab4466ad86 Mon Sep 17 00:00:00 2001 From: He Wang Date: Tue, 12 May 2026 23:21:22 +0800 Subject: [PATCH 17/86] perf(harness): push thread metadata filters into SQL (#2865) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * perf(harness): push thread metadata filters into SQL Replace Python-side metadata filtering (5x overfetch + in-memory match) with database-side json_extract predicates so LIMIT/OFFSET pagination is exact regardless of match density. Co-Authored-By: Claude Opus 4 * fix(harness): add dialect-aware JsonMatch compiler for type-safe metadata SQL filters Replace SQLAlchemy JSON index/comparator APIs with a custom JsonMatch ColumnElement that compiles to json_type/json_extract on SQLite and jsonb_typeof/->>/-> on PostgreSQL. Tighten key validation regex to single-segment identifiers, handle None/bool/numeric value types with json_type-based discrimination, and strengthen test coverage for edge cases and discriminability. Co-Authored-By: Claude Opus 4 * fix(harness): address Copilot review comments on JSON metadata filters - Use json_typeof instead of jsonb_typeof in PostgreSQL compiler; the metadata_json column is JSON not JSONB so jsonb_typeof would error at runtime on any PostgreSQL backend - Align _is_safe_json_key with json_match's _KEY_CHARSET_RE so keys containing hyphens or leading digits are not silently skipped - Add thread_id as secondary ORDER BY in search() to make pagination deterministic when updated_at values collide; remove asyncio.sleep from the pagination regression test Co-Authored-By: Claude Sonnet 4 * fix(harness): address remaining review comments on metadata SQL filters - Remove _is_safe_json_key() and reuse json_match ValueError to avoid validator drift (Copilot #3217603895, #3217411616) - Raise ValueError when all metadata keys are rejected so callers never get silent unfiltered results (WillemJiang) - Fix integer precision: split int/float branches, bind int as Integer() with INTEGER/BIGINT CAST instead of float() coercion (Copilot #3217603972) - Fix jsonb_typeof -> json_typeof on JSON column (Copilot #3217411579) - Replace manual _cleanup() calls with async yield fixture so teardown always runs (Copilot #3217604019) - Remove asyncio.sleep(0.01) pagination ordering; use thread_id secondary sort instead (Copilot #3217411636) - Add type annotations to _bind/_build_clause/_compile_* and remove EOL comments from _Dialect fields (coding.mdc) - Expand test coverage: boolean/null/mixed-type/large-int precision, partial unsafe-key skip with caplog assertion Co-Authored-By: Claude Sonnet 4.6 * fix(harness): address third-round Copilot review comments on JsonMatch - Reject unsupported value types (list, dict, ...) in JsonMatch.__init__ with TypeError so inherit_cache=True never receives an unhashable value and callers get an explicit error instead of silent str() coercion (Copilot #3217933201) - Upgrade int bindparam from Integer() to BigInteger() to align with BIGINT CAST and avoid overflow on large integers (Copilot #3217933252) - Catch TypeError alongside ValueError in search() so non-string metadata keys are warned and skipped rather than raising unexpectedly (Copilot #3217933300) - Add three tests: json_match rejects unsupported value types, search() warns and raises on non-string key, search() warns and raises on unsupported value type Co-Authored-By: Claude Sonnet 4.6 * fix(harness): address fourth-round Copilot review comments on JsonMatch - Add CASE WHEN guard for PostgreSQL integer matching: json_typeof returns 'number' for both ints and floats; wrap CAST in CASE with regex guard '^-?[0-9]+$' so float rows never trigger CAST error (Copilot #3218413860) - Validate isinstance(key, str) before regex match in JsonMatch.__init__ so non-string keys raise ValueError consistently instead of TypeError from re.match (Copilot #3218413900) - Include exception message in metadata filter skip warning so callers can distinguish invalid key from unsupported value type (Copilot #3218413924) - Update tests: assert CASE WHEN guard in PG int compilation, cover non-string key ValueError in test_json_match_rejects_unsafe_key Co-Authored-By: Claude Sonnet 4.6 * fix(harness): align ThreadMetaStore.search() signature with sql.py implementation Use `dict[str, Any]` for `metadata` and `list[dict[str, Any]]` as return type in base class and MemoryThreadMetaStore to resolve an LSP signature mismatch; also correct a test docstring that cited the wrong exception type. Co-Authored-By: Claude Sonnet 4.6 * fix(harness): surface InvalidMetadataFilterError as HTTP 400 in search endpoint Replace bare ValueError with a domain-specific InvalidMetadataFilterError (subclass of ValueError) so the Gateway handler can catch it and return HTTP 400 instead of letting it bubble up as a 500. Co-Authored-By: Claude Opus 4 * fix(harness): sanitize metadata keys in log output to prevent log injection Use ascii() instead of %r to escape control characters in client-supplied metadata keys before logging, preventing multiline/forged log entries. Co-Authored-By: Claude Opus 4 * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix(harness): validate metadata filters at API boundary and dedupe key/value rules - Add Pydantic ``field_validator`` on ``ThreadSearchRequest.metadata`` so unsafe keys / unsupported value types are rejected with HTTP 422 from both SQL and memory backends (closes Copilot review 3218830849). - Export ``validate_metadata_filter_key`` / ``validate_metadata_filter_value`` (and ``ALLOWED_FILTER_VALUE_TYPES``) from ``json_compat`` and have ``JsonMatch.__init__`` reuse them — the Gateway-side validator and the SQL-side ``JsonMatch`` constructor now share one admission rule and cannot drift. - Format ``InvalidMetadataFilterError`` rejected-keys list as a comma-separated plain string instead of a Python list repr so the surfaced HTTP 400 detail is readable (closes Copilot review 3218830899). - Update router tests to cover both 422 boundary paths plus the 400 defense-in-depth path when a backend still raises the error. Co-authored-by: Cursor * fix(harness): harden JsonMatch compile-time key validation against __init__ bypass Co-Authored-By: Claude Sonnet 4 * fix: address review feedback on metadata filter SQL push-down - Add signed 64-bit range check to validate_metadata_filter_value; give out-of-range ints a distinct TypeError message. - Replace assert guards in _compile_sqlite/_compile_pg with explicit if/raise so they survive python -O optimisation. Co-Authored-By: Claude Sonnet 4 --------- Co-authored-by: Claude Opus 4 Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: Cursor --- backend/app/gateway/routers/threads.py | 38 +- .../deerflow/persistence/json_compat.py | 195 +++++++ .../persistence/thread_meta/__init__.py | 3 +- .../deerflow/persistence/thread_meta/base.py | 9 +- .../persistence/thread_meta/memory.py | 4 +- .../deerflow/persistence/thread_meta/sql.py | 46 +- backend/tests/test_thread_meta_repo.py | 504 +++++++++++++++--- backend/tests/test_threads_router.py | 54 ++ 8 files changed, 757 insertions(+), 96 deletions(-) create mode 100644 backend/packages/harness/deerflow/persistence/json_compat.py diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index cb048152e..e6f4fa2ae 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -90,6 +90,28 @@ class ThreadSearchRequest(BaseModel): offset: int = Field(default=0, ge=0, description="Pagination offset") status: str | None = Field(default=None, description="Filter by thread status") + @field_validator("metadata") + @classmethod + def _validate_metadata_filters(cls, v: dict[str, Any]) -> dict[str, Any]: + """Reject filter entries the SQL backend cannot compile. + + Enforces consistent behaviour across SQL and memory backends. + See ``deerflow.persistence.json_compat`` for the shared validators. + """ + if not v: + return v + from deerflow.persistence.json_compat import validate_metadata_filter_key, validate_metadata_filter_value + + bad_entries: list[str] = [] + for key, value in v.items(): + if not validate_metadata_filter_key(key): + bad_entries.append(f"{key!r} (unsafe key)") + elif not validate_metadata_filter_value(value): + bad_entries.append(f"{key!r} (unsupported value type {type(value).__name__})") + if bad_entries: + raise ValueError(f"Invalid metadata filter entries: {', '.join(bad_entries)}") + return v + class ThreadStateResponse(BaseModel): """Response model for thread state.""" @@ -294,14 +316,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th (SQL-backed for sqlite/postgres, Store-backed for memory mode). """ from app.gateway.deps import get_thread_store + from deerflow.persistence.thread_meta import InvalidMetadataFilterError repo = get_thread_store(request) - rows = await repo.search( - metadata=body.metadata or None, - status=body.status, - limit=body.limit, - offset=body.offset, - ) + try: + rows = await repo.search( + metadata=body.metadata or None, + status=body.status, + limit=body.limit, + offset=body.offset, + ) + except InvalidMetadataFilterError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc return [ ThreadResponse( thread_id=r["thread_id"], diff --git a/backend/packages/harness/deerflow/persistence/json_compat.py b/backend/packages/harness/deerflow/persistence/json_compat.py new file mode 100644 index 000000000..442b29e22 --- /dev/null +++ b/backend/packages/harness/deerflow/persistence/json_compat.py @@ -0,0 +1,195 @@ +"""Dialect-aware JSON value matching for SQLAlchemy (SQLite + PostgreSQL).""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import BigInteger, Float, String, bindparam +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.expression import ColumnElement +from sqlalchemy.sql.visitors import InternalTraversal +from sqlalchemy.types import Boolean, TypeEngine + +# Key is interpolated into compiled SQL; restrict charset to prevent injection. +_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$") + +# Allowed value types for metadata filter values (same set accepted by JsonMatch). +ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str) + +# SQLite raises an overflow when binding values outside signed 64-bit range; +# PostgreSQL overflows during BIGINT cast. Reject at validation time instead. +_INT64_MIN = -(2**63) +_INT64_MAX = 2**63 - 1 + + +def validate_metadata_filter_key(key: object) -> bool: + """Return True if *key* is safe for use as a JSON metadata filter key. + + A key is "safe" when it is a string matching ``[A-Za-z0-9_-]+``. The + charset is restricted because the key is interpolated into the + compiled SQL path expression (``$.""`` / ``->`` literal), so any + laxer pattern would open a SQL/JSONPath injection surface. + """ + return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key)) + + +def validate_metadata_filter_value(value: object) -> bool: + """Return True if *value* is an allowed type for a JSON metadata filter. + + Matches the set of types ``_build_clause`` knows how to compile into + a dialect-portable predicate. Anything else (list/dict/bytes/...) is + intentionally rejected rather than silently coerced via ``str()`` — + silent coercion would (a) produce wrong matches and (b) break + SQLAlchemy's ``inherit_cache`` invariant when ``value`` is unhashable. + + Integer values are additionally restricted to the signed 64-bit range + ``[-2**63, 2**63 - 1]``: SQLite overflows when binding larger values + and PostgreSQL overflows during the ``BIGINT`` cast. + """ + if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES): + return False + if isinstance(value, int) and not isinstance(value, bool): + if not (_INT64_MIN <= value <= _INT64_MAX): + return False + return True + + +class JsonMatch(ColumnElement): + """Dialect-portable ``column[key] == value`` for JSON columns. + + Compiles to ``json_type``/``json_extract`` on SQLite and + ``json_typeof``/``->>`` on PostgreSQL, with type-safe comparison + that distinguishes bool vs int and NULL vs missing key. + + *key* must be a single literal key matching ``[A-Za-z0-9_-]+``. + *value* must be one of: ``None``, ``bool``, ``int`` (signed 64-bit), ``float``, ``str``. + """ + + inherit_cache = True + type = Boolean() + _is_implicitly_boolean = True + + _traverse_internals = [ + ("column", InternalTraversal.dp_clauseelement), + ("key", InternalTraversal.dp_string), + ("value", InternalTraversal.dp_plain_obj), + ] + + def __init__(self, column: ColumnElement, key: str, value: object) -> None: + if not validate_metadata_filter_key(key): + raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}") + if not validate_metadata_filter_value(value): + if isinstance(value, int) and not isinstance(value, bool): + raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}") + raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}") + self.column = column + self.key = key + self.value = value + super().__init__() + + +@dataclass(frozen=True) +class _Dialect: + """Per-dialect names used when emitting JSON type/value comparisons.""" + + null_type: str + num_types: tuple[str, ...] + num_cast: str + int_types: tuple[str, ...] + int_cast: str + # None for SQLite where json_type already returns 'integer'/'real'; + # regex literal for PostgreSQL where json_typeof returns 'number' for + # both ints and floats, so an extra guard prevents CAST errors on floats. + int_guard: str | None + string_type: str + bool_type: str | None + + +_SQLITE = _Dialect( + null_type="null", + num_types=("integer", "real"), + num_cast="REAL", + int_types=("integer",), + int_cast="INTEGER", + int_guard=None, + string_type="text", + bool_type=None, +) + +_PG = _Dialect( + null_type="null", + num_types=("number",), + num_cast="DOUBLE PRECISION", + int_types=("number",), + int_cast="BIGINT", + int_guard="'^-?[0-9]+$'", + string_type="string", + bool_type="boolean", +) + + +def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str: + param = bindparam(None, value, type_=sa_type) + return compiler.process(param, **kw) + + +def _type_check(typeof: str, types: tuple[str, ...]) -> str: + if len(types) == 1: + return f"{typeof} = '{types[0]}'" + quoted = ", ".join(f"'{t}'" for t in types) + return f"{typeof} IN ({quoted})" + + +def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str: + if value is None: + return f"{typeof} = '{dialect.null_type}'" + if isinstance(value, bool): + # bool check must precede int check — bool is a subclass of int in Python + bool_str = "true" if value else "false" + if dialect.bool_type is None: + return f"{typeof} = '{bool_str}'" + return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')" + if isinstance(value, int): + bp = _bind(compiler, value, BigInteger(), **kw) + if dialect.int_guard: + # CASE prevents CAST error when json_typeof = 'number' also matches floats + return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})" + return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})" + if isinstance(value, float): + bp = _bind(compiler, value, Float(), **kw) + return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})" + bp = _bind(compiler, str(value), String(), **kw) + return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})" + + +@compiles(JsonMatch, "sqlite") +def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + path = f'$."{element.key}"' + typeof = f"json_type({col}, '{path}')" + extract = f"json_extract({col}, '{path}')" + return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw) + + +@compiles(JsonMatch, "postgresql") +def _compile_pg(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + typeof = f"json_typeof({col} -> '{element.key}')" + extract = f"({col} ->> '{element.key}')" + return _build_clause(compiler, typeof, extract, element.value, _PG, **kw) + + +@compiles(JsonMatch) +def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + raise NotImplementedError(f"JsonMatch supports only sqlite and postgresql; got dialect: {compiler.dialect.name}") + + +def json_match(column: ColumnElement, key: str, value: object) -> JsonMatch: + return JsonMatch(column, key, value) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py index 080ce8093..b5231f0f9 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.sql import ThreadMetaRepository @@ -14,6 +14,7 @@ if TYPE_CHECKING: from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker __all__ = [ + "InvalidMetadataFilterError", "MemoryThreadMetaStore", "ThreadMetaRepository", "ThreadMetaRow", diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index c87c10a16..ed55ade8e 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -15,10 +15,15 @@ three-state semantics (see :mod:`deerflow.runtime.user_context`): from __future__ import annotations import abc +from typing import Any from deerflow.runtime.user_context import AUTO, _AutoSentinel +class InvalidMetadataFilterError(ValueError): + """Raised when all client-supplied metadata filter keys are rejected.""" + + class ThreadMetaStore(abc.ABC): @abc.abstractmethod async def create( @@ -40,12 +45,12 @@ class ThreadMetaStore(abc.ABC): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: pass @abc.abstractmethod diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index fbe66fdaf..4f642a938 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -69,12 +69,12 @@ class MemoryThreadMetaStore(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search") filter_dict: dict[str, Any] = {} if metadata: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 688fbb247..0d3f587de 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -2,16 +2,20 @@ from __future__ import annotations +import logging from datetime import UTC, datetime from typing import Any from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from deerflow.persistence.thread_meta.base import ThreadMetaStore +from deerflow.persistence.json_compat import json_match +from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +logger = logging.getLogger(__name__) + class ThreadMetaRepository(ThreadMetaStore): def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: @@ -20,7 +24,7 @@ class ThreadMetaRepository(ThreadMetaStore): @staticmethod def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]: d = row.to_dict() - d["metadata"] = d.pop("metadata_json", {}) + d["metadata"] = d.pop("metadata_json", None) or {} for key in ("created_at", "updated_at"): val = d.get(key) if isinstance(val, datetime): @@ -104,39 +108,43 @@ class ThreadMetaRepository(ThreadMetaStore): async def search( self, *, - metadata: dict | None = None, + metadata: dict[str, Any] | None = None, status: str | None = None, limit: int = 100, offset: int = 0, user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: + ) -> list[dict[str, Any]]: """Search threads with optional metadata and status filters. Owner filter is enforced by default: caller must be in a user context. Pass ``user_id=None`` to bypass (migration/CLI). """ resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search") - stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) + stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc(), ThreadMetaRow.thread_id.desc()) if resolved_user_id is not None: stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id) if status: stmt = stmt.where(ThreadMetaRow.status == status) if metadata: - # When metadata filter is active, fetch a larger window and filter - # in Python. TODO(Phase 2): use JSON DB operators (Postgres @>, - # SQLite json_extract) for server-side filtering. - stmt = stmt.limit(limit * 5 + offset) - async with self._sf() as session: - result = await session.execute(stmt) - rows = [self._row_to_dict(r) for r in result.scalars()] - rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())] - return rows[offset : offset + limit] - else: - stmt = stmt.limit(limit).offset(offset) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] + applied = 0 + for key, value in metadata.items(): + try: + stmt = stmt.where(json_match(ThreadMetaRow.metadata_json, key, value)) + applied += 1 + except (ValueError, TypeError) as exc: + logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc) + if applied == 0: + # Comma-separated plain string (no list repr / nested + # quoting) so the 400 detail surfaced by the Gateway is + # easy for clients to read. Sorted for determinism. + rejected_keys = ", ".join(sorted(str(k) for k in metadata)) + raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}") + + stmt = stmt.limit(limit).offset(offset) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool: """Return True if the row exists and is owned (or filter bypassed).""" diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index 3a6532567..1cef3752b 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -1,28 +1,25 @@ """Tests for ThreadMetaRepository (SQLAlchemy-backed).""" +import logging + import pytest -from deerflow.persistence.thread_meta import ThreadMetaRepository +from deerflow.persistence.thread_meta import InvalidMetadataFilterError, ThreadMetaRepository -async def _make_repo(tmp_path): - from deerflow.persistence.engine import get_session_factory, init_engine +@pytest.fixture +async def repo(tmp_path): + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) - return ThreadMetaRepository(get_session_factory()) - - -async def _cleanup(): - from deerflow.persistence.engine import close_engine - + yield ThreadMetaRepository(get_session_factory()) await close_engine() class TestThreadMetaRepository: @pytest.mark.anyio - async def test_create_and_get(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_and_get(self, repo): record = await repo.create("t1") assert record["thread_id"] == "t1" assert record["status"] == "idle" @@ -31,148 +28,523 @@ class TestThreadMetaRepository: fetched = await repo.get("t1") assert fetched is not None assert fetched["thread_id"] == "t1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_assistant_id(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_assistant_id(self, repo): record = await repo.create("t1", assistant_id="agent1") assert record["assistant_id"] == "agent1" - await _cleanup() @pytest.mark.anyio - async def test_create_with_owner_and_display_name(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_owner_and_display_name(self, repo): record = await repo.create("t1", user_id="user1", display_name="My Thread") assert record["user_id"] == "user1" assert record["display_name"] == "My Thread" - await _cleanup() @pytest.mark.anyio - async def test_create_with_metadata(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_create_with_metadata(self, repo): record = await repo.create("t1", metadata={"key": "value"}) assert record["metadata"] == {"key": "value"} - await _cleanup() @pytest.mark.anyio - async def test_get_nonexistent(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_get_nonexistent(self, repo): assert await repo.get("nonexistent") is None - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_record_allows(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_record_allows(self, repo): assert await repo.check_access("unknown", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_matches(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_owner_matches(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_owner_mismatch(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_owner_mismatch(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2") is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_no_owner_allows_all(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_no_owner_allows_all(self, repo): # Explicit user_id=None to bypass the new AUTO default that # would otherwise pick up the test user from the autouse fixture. await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone") is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_missing_row_denied(self, tmp_path): + async def test_check_access_strict_missing_row_denied(self, repo): """require_existing=True flips the missing-row case to *denied*. Closes the delete-idempotence cross-user gap: after a thread is deleted, the row is gone, and the permissive default would let any caller "claim" it as untracked. The strict mode demands a row. """ - repo = await _make_repo(tmp_path) assert await repo.check_access("never-existed", "user1", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_match_allowed(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_match_allowed(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user1", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_owner_mismatch_denied(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_check_access_strict_owner_mismatch_denied(self, repo): await repo.create("t1", user_id="user1") assert await repo.check_access("t1", "user2", require_existing=True) is False - await _cleanup() @pytest.mark.anyio - async def test_check_access_strict_null_owner_still_allowed(self, tmp_path): + async def test_check_access_strict_null_owner_still_allowed(self, repo): """Even in strict mode, a row with NULL user_id stays shared. The strict flag tightens the *missing row* case, not the *shared row* case — legacy pre-auth rows that survived a clean migration without an owner are still everyone's. """ - repo = await _make_repo(tmp_path) await repo.create("t1", user_id=None) assert await repo.check_access("t1", "anyone", require_existing=True) is True - await _cleanup() @pytest.mark.anyio - async def test_update_status(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_status(self, repo): await repo.create("t1") await repo.update_status("t1", "busy") record = await repo.get("t1") assert record["status"] == "busy" - await _cleanup() @pytest.mark.anyio - async def test_delete(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete(self, repo): await repo.create("t1") await repo.delete("t1") assert await repo.get("t1") is None - await _cleanup() @pytest.mark.anyio - async def test_delete_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_delete_nonexistent_is_noop(self, repo): await repo.delete("nonexistent") # should not raise - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_merges(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_merges(self, repo): await repo.create("t1", metadata={"a": 1, "b": 2}) await repo.update_metadata("t1", {"b": 99, "c": 3}) record = await repo.get("t1") # Existing key preserved, overlapping key overwritten, new key added assert record["metadata"] == {"a": 1, "b": 99, "c": 3} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_on_empty(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_on_empty(self, repo): await repo.create("t1") await repo.update_metadata("t1", {"k": "v"}) record = await repo.get("t1") assert record["metadata"] == {"k": "v"} - await _cleanup() @pytest.mark.anyio - async def test_update_metadata_nonexistent_is_noop(self, tmp_path): - repo = await _make_repo(tmp_path) + async def test_update_metadata_nonexistent_is_noop(self, repo): await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise - await _cleanup() + + # --- search with metadata filter (SQL push-down) --- + + @pytest.mark.anyio + async def test_search_metadata_filter_string(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + await repo.create("t3", metadata={"env": "prod", "region": "us"}) + + results = await repo.search(metadata={"env": "prod"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_numeric(self, repo): + await repo.create("t1", metadata={"priority": 1}) + await repo.create("t2", metadata={"priority": 2}) + await repo.create("t3", metadata={"priority": 1, "extra": "x"}) + + results = await repo.search(metadata={"priority": 1}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_multiple_keys(self, repo): + await repo.create("t1", metadata={"env": "prod", "region": "us"}) + await repo.create("t2", metadata={"env": "prod", "region": "eu"}) + await repo.create("t3", metadata={"env": "staging", "region": "us"}) + + results = await repo.search(metadata={"env": "prod", "region": "us"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_metadata_no_match(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "dev"}) + assert results == [] + + @pytest.mark.anyio + async def test_search_metadata_pagination_correct(self, repo): + """Regression: SQL push-down makes limit/offset exact even when most rows don't match.""" + for i in range(30): + meta = {"target": "yes"} if i % 3 == 0 else {"target": "no"} + await repo.create(f"t{i:03d}", metadata=meta) + + # Total matching rows: i in {0,3,6,9,12,15,18,21,24,27} = 10 rows + all_matches = await repo.search(metadata={"target": "yes"}, limit=100) + assert len(all_matches) == 10 + + # Paginate: first page + page1 = await repo.search(metadata={"target": "yes"}, limit=3, offset=0) + assert len(page1) == 3 + + # Paginate: second page + page2 = await repo.search(metadata={"target": "yes"}, limit=3, offset=3) + assert len(page2) == 3 + + # No overlap between pages + page1_ids = {r["thread_id"] for r in page1} + page2_ids = {r["thread_id"] for r in page2} + assert page1_ids.isdisjoint(page2_ids) + + # Last page + page_last = await repo.search(metadata={"target": "yes"}, limit=3, offset=9) + assert len(page_last) == 1 + + @pytest.mark.anyio + async def test_search_metadata_with_status_filter(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "prod"}) + await repo.update_status("t1", "busy") + + results = await repo.search(metadata={"env": "prod"}, status="busy") + assert len(results) == 1 + assert results[0]["thread_id"] == "t1" + + @pytest.mark.anyio + async def test_search_without_metadata_still_works(self, repo): + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2") + + results = await repo.search(limit=10) + assert len(results) == 2 + + @pytest.mark.anyio + async def test_search_metadata_missing_key_no_match(self, repo): + """Rows without the requested metadata key should not match.""" + await repo.create("t1", metadata={"other": "val"}) + await repo.create("t2", metadata={"env": "prod"}) + + results = await repo.search(metadata={"env": "prod"}) + assert len(results) == 1 + assert results[0]["thread_id"] == "t2" + + @pytest.mark.anyio + async def test_search_metadata_all_unsafe_keys_raises(self, repo, caplog): + """When ALL metadata keys are unsafe, raises InvalidMetadataFilterError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected") as exc_info: + await repo.search(metadata={"bad;key": "x"}) + assert any("bad;key" in r.message for r in caplog.records) + # Subclass of ValueError for backward compatibility + assert isinstance(exc_info.value, ValueError) + + @pytest.mark.anyio + async def test_search_metadata_partial_unsafe_key_skipped(self, repo, caplog): + """Valid keys filter rows; only the invalid key is warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + results = await repo.search(metadata={"env": "prod", "bad;key": "x"}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + assert any("bad;key" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_filter_boolean(self, repo): + """True matches only boolean true, not integer 1.""" + await repo.create("t1", metadata={"active": True}) + await repo.create("t2", metadata={"active": False}) + await repo.create("t3", metadata={"active": True, "extra": "x"}) + await repo.create("t4", metadata={"active": 1}) + + results = await repo.search(metadata={"active": True}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_filter_none(self, repo): + """Only rows with explicit JSON null match; missing key does not.""" + await repo.create("t1", metadata={"tag": None}) + await repo.create("t2", metadata={"tag": "present"}) + await repo.create("t3", metadata={"other": "val"}) + + results = await repo.search(metadata={"tag": None}) + ids = {r["thread_id"] for r in results} + assert ids == {"t1"} + + @pytest.mark.anyio + async def test_search_metadata_non_string_key_skipped(self, repo, caplog): + """Non-string keys raise ValueError from isinstance check; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={1: "x"}) + assert any("1" in r.message for r in caplog.records) + + @pytest.mark.anyio + async def test_search_metadata_unsupported_value_type_skipped(self, repo, caplog): + """Unsupported value types (list, dict) raise TypeError; should be warned and skipped.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"env": ["prod", "staging"]}) + + @pytest.mark.anyio + async def test_search_metadata_dotted_key_raises(self, repo, caplog): + """Dotted keys are rejected; when ALL keys are dotted, raises ValueError.""" + await repo.create("t1", metadata={"env": "prod"}) + await repo.create("t2", metadata={"env": "staging"}) + + with caplog.at_level(logging.WARNING, logger="deerflow.persistence.thread_meta.sql"): + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search(metadata={"a.b": "anything"}) + assert any("a.b" in r.message for r in caplog.records) + + # --- dialect-aware type-safe filtering edge cases --- + + @pytest.mark.anyio + async def test_search_metadata_bool_vs_int_distinction(self, repo): + """True must not match 1; False must not match 0.""" + await repo.create("bool_true", metadata={"flag": True}) + await repo.create("bool_false", metadata={"flag": False}) + await repo.create("int_one", metadata={"flag": 1}) + await repo.create("int_zero", metadata={"flag": 0}) + + true_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": True})} + assert true_hits == {"bool_true"} + + false_hits = {r["thread_id"] for r in await repo.search(metadata={"flag": False})} + assert false_hits == {"bool_false"} + + @pytest.mark.anyio + async def test_search_metadata_int_does_not_match_bool(self, repo): + """Integer 1 must not match boolean True.""" + await repo.create("bool_true", metadata={"val": True}) + await repo.create("int_one", metadata={"val": 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"val": 1})} + assert hits == {"int_one"} + + @pytest.mark.anyio + async def test_search_metadata_none_excludes_missing_key(self, repo): + """Filtering by None matches explicit JSON null only, not missing key or empty {}.""" + await repo.create("explicit_null", metadata={"k": None}) + await repo.create("missing_key", metadata={"other": "x"}) + await repo.create("empty_obj", metadata={}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"k": None})} + assert hits == {"explicit_null"} + + @pytest.mark.anyio + async def test_search_metadata_float_value(self, repo): + await repo.create("t1", metadata={"score": 3.14}) + await repo.create("t2", metadata={"score": 2.71}) + await repo.create("t3", metadata={"score": 3.14}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"score": 3.14})} + assert hits == {"t1", "t3"} + + @pytest.mark.anyio + async def test_search_metadata_mixed_types_same_key(self, repo): + """Each type query only matches its own type, even when the key is shared.""" + await repo.create("str_row", metadata={"x": "hello"}) + await repo.create("int_row", metadata={"x": 42}) + await repo.create("bool_row", metadata={"x": True}) + await repo.create("null_row", metadata={"x": None}) + + assert {r["thread_id"] for r in await repo.search(metadata={"x": "hello"})} == {"str_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": 42})} == {"int_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": True})} == {"bool_row"} + assert {r["thread_id"] for r in await repo.search(metadata={"x": None})} == {"null_row"} + + @pytest.mark.anyio + async def test_search_metadata_large_int_precision(self, repo): + """Integers beyond float precision (> 2**53) must match exactly.""" + large = 2**53 + 1 + await repo.create("t1", metadata={"id": large}) + await repo.create("t2", metadata={"id": large - 1}) + + hits = {r["thread_id"] for r in await repo.search(metadata={"id": large})} + assert hits == {"t1"} + + +class TestJsonMatchCompilation: + """Verify compiled SQL for both SQLite and PostgreSQL dialects.""" + + def test_json_match_compiles_sqlite(self): + from sqlalchemy import Column, MetaData, String, Table, create_engine + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + engine = create_engine("sqlite://") + + cases = [ + (None, "json_type(t.data, '$.\"k\"') = 'null'"), + (True, "json_type(t.data, '$.\"k\"') = 'true'"), + (False, "json_type(t.data, '$.\"k\"') = 'false'"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: uses INTEGER cast for precision, type-check narrows to 'integer' only + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "= 'integer'" in sql + assert "INTEGER" in sql + assert "CAST" in sql + + # float: uses REAL cast, type-check spans 'integer' and 'real' + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "IN ('integer', 'real')" in sql + assert "REAL" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + assert "json_type" in sql + assert "'text'" in sql + + def test_json_match_compiles_pg(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import postgresql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + dialect = postgresql.dialect() + + cases = [ + (None, "json_typeof(t.data -> 'k') = 'null'"), + (True, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'true')"), + (False, "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"), + ] + for value, expected_fragment in cases: + expr = json_match(t.c.data, "k", value) + sql = expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True}) + assert str(sql) == expected_fragment, f"value={value!r}: {sql}" + + # int: CASE guard prevents CAST error when 'number' also matches floats + int_expr = json_match(t.c.data, "k", 42) + sql = str(int_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "BIGINT" in sql + assert "CASE WHEN" in sql + assert "'^-?[0-9]+$'" in sql + + # float: uses DOUBLE PRECISION cast + float_expr = json_match(t.c.data, "k", 3.14) + sql = str(float_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'number'" in sql + assert "DOUBLE PRECISION" in sql + + str_expr = json_match(t.c.data, "k", "hello") + sql = str(str_expr.compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "json_typeof" in sql + assert "'string'" in sql + + def test_json_match_rejects_unsafe_key(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_key in ["a.b", "with space", "bad'quote", 'bad"quote', "back\\slash", "semi;colon", ""]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, bad_key, "x") + + # Non-string keys must also raise ValueError (not TypeError from re.match) + for non_str_key in [42, None, ("k",)]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(t.c.data, non_str_key, "x") + + def test_json_match_rejects_unsupported_value_type(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + for bad_value in [[], {}, object()]: + with pytest.raises(TypeError, match="JsonMatch value must be"): + json_match(t.c.data, "k", bad_value) + + def test_json_match_unsupported_dialect_raises(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.dialects import mysql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + expr = json_match(t.c.data, "k", "v") + + with pytest.raises(NotImplementedError, match="mysql"): + str(expr.compile(dialect=mysql.dialect(), compile_kwargs={"literal_binds": True})) + + def test_json_match_rejects_out_of_range_int(self): + from sqlalchemy import Column, MetaData, String, Table + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + + # boundary values must be accepted + json_match(t.c.data, "k", 2**63 - 1) + json_match(t.c.data, "k", -(2**63)) + + # one beyond each boundary must be rejected + for out_of_range in [2**63, -(2**63) - 1, 10**30]: + with pytest.raises(TypeError, match="out of signed 64-bit range"): + json_match(t.c.data, "k", out_of_range) + + def test_compiler_raises_on_escaped_key(self): + """Compiler raises ValueError even when __init__ validation is bypassed.""" + from sqlalchemy import Column, MetaData, String, Table, create_engine + from sqlalchemy.dialects import postgresql + from sqlalchemy.types import JSON + + from deerflow.persistence.json_compat import json_match + + metadata = MetaData() + t = Table("t", metadata, Column("data", JSON), Column("id", String)) + engine = create_engine("sqlite://") + + elem = json_match(t.c.data, "k", "v") + elem.key = "bad.key" # bypass __init__ to simulate -O stripping assert + + with pytest.raises(ValueError, match="Key escaped validation"): + str(elem.compile(dialect=engine.dialect, compile_kwargs={"literal_binds": True})) + + with pytest.raises(ValueError, match="Key escaped validation"): + str(elem.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index daf0c0b13..9e37f3c86 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -10,6 +10,7 @@ from langgraph.store.memory import InMemoryStore from app.gateway.routers import threads from deerflow.config.paths import Paths +from deerflow.persistence.thread_meta import InvalidMetadataFilterError from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore _ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -431,3 +432,56 @@ def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None assert entries, "expected at least one history entry" for entry in entries: assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry + + +# ── Metadata filter validation at API boundary ──────────────────────────────── + + +def test_search_threads_rejects_invalid_key_at_api_boundary() -> None: + """Keys that don't match [A-Za-z0-9_-]+ are rejected by the Pydantic + validator on ThreadSearchRequest.metadata — 422 from both backends. + """ + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"bad;key": "x"}}) + + assert response.status_code == 422 + + +def test_search_threads_rejects_unsupported_value_type_at_api_boundary() -> None: + """Value types outside (None, bool, int, float, str) are rejected.""" + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"env": ["a", "b"]}}) + + assert response.status_code == 422 + + +def test_search_threads_returns_400_for_backend_invalid_metadata_filter() -> None: + """If the backend still raises InvalidMetadataFilterError (defense in + depth), the handler surfaces it as HTTP 400. + """ + app, _store, _checkpointer = _build_thread_app() + thread_store = app.state.thread_store + + async def _raise(**kwargs): + raise InvalidMetadataFilterError("rejected") + + with TestClient(app) as client: + with patch.object(thread_store, "search", side_effect=_raise): + response = client.post("/api/threads/search", json={"metadata": {"valid_key": "x"}}) + + assert response.status_code == 400 + assert "rejected" in response.json()["detail"] + + +def test_search_threads_succeeds_with_valid_metadata() -> None: + """Sanity check: valid metadata passes through without error.""" + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"metadata": {"env": "prod"}}) + + assert response.status_code == 200 From 2a1ac06bf4ee0efdc6bc312d5b0548a321f591d6 Mon Sep 17 00:00:00 2001 From: Eilen Shin <136898293+Eilen6316@users.noreply.github.com> Date: Wed, 13 May 2026 15:49:34 +0800 Subject: [PATCH 18/86] fix(persistence): reuse token usage model grouping expression (#2910) --- .../harness/deerflow/persistence/run/sql.py | 5 +- backend/tests/test_run_repository.py | 48 +++++++++++++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 430fbe4f6..5331451e3 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -223,10 +223,11 @@ class RunRepository(RunStore): """Aggregate token usage via a single SQL GROUP BY query.""" _completed = RunRow.status.in_(("success", "error")) _thread = RunRow.thread_id == thread_id + model_name = func.coalesce(RunRow.model_name, "unknown") stmt = ( select( - func.coalesce(RunRow.model_name, "unknown").label("model"), + model_name.label("model"), func.count().label("runs"), func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"), func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"), @@ -236,7 +237,7 @@ class RunRepository(RunStore): func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"), ) .where(_thread, _completed) - .group_by(func.coalesce(RunRow.model_name, "unknown")) + .group_by(model_name) ) async with self._sf() as session: diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 6fd534829..5e230e790 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -3,7 +3,10 @@ Uses a temp SQLite DB to test ORM-backed CRUD operations. """ +import re + import pytest +from sqlalchemy.dialects import postgresql from deerflow.persistence.run import RunRepository @@ -278,3 +281,48 @@ class TestRunRepository: assert row4["model_name"] is None await _cleanup() + + @pytest.mark.anyio + async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self): + captured = [] + + class FakeResult: + def all(self): + return [] + + class FakeSession: + async def execute(self, stmt): + captured.append(stmt) + return FakeResult() + + class FakeSessionContext: + async def __aenter__(self): + return FakeSession() + + async def __aexit__(self, exc_type, exc, tb): + return None + + repo = RunRepository(lambda: FakeSessionContext()) + + agg = await repo.aggregate_tokens_by_thread("t1") + assert agg == { + "total_tokens": 0, + "total_input_tokens": 0, + "total_output_tokens": 0, + "total_runs": 0, + "by_model": {}, + "by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0}, + } + assert len(captured) == 1 + + stmt = captured[0] + compiled_sql = str(stmt.compile(dialect=postgresql.dialect())) + select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1) + model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)" + + select_match = re.search(model_expr_pattern + r" AS model", select_sql) + group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip()) + + assert select_match is not None + assert group_by_match is not None + assert select_match.group(1) == group_by_match.group(1) From f1a0ab699aee5642fccf9f0fc211b231d41e6b5d Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Wed, 13 May 2026 23:45:47 +0800 Subject: [PATCH 19/86] fix(tools): preserve tool_search promotions across re-entrant get_available_tools (#2885) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(tools): preserve tool_search promotions across re-entrant get_available_tools Closes #2884. ``get_available_tools`` used to unconditionally call ``reset_deferred_registry()`` and rebuild a fresh ``DeferredToolRegistry`` on every invocation. That works for the first call of a request (the ContextVar starts at its default of ``None``), but any RE-ENTRANT call during the same async context — e.g. ``task_tool`` building a subagent's toolset, or a custom middleware that rebuilds tools mid-run — wiped any ``tool_search`` promotions the parent agent had already made. The ``DeferredToolFilterMiddleware`` would then re-hide those tools from the next model call, leaving the agent able to see a tool's name (via the prior ``tool_search`` result that's still in conversation history) but unable to invoke it. Fix: when the ContextVar already holds a registry, reuse it instead of rebuilding. Fresh requests still get a fresh registry because each new graph run starts in a new asyncio task with the ContextVar at ``None``. ## Verification - Unit-level reproduction (``test_get_available_tools_resets_registry_wiping_promotion``): promote a tool in the registry, call ``get_available_tools`` again, assert the promotion is preserved. Fails on main, passes on this branch. - Graph-execution reproduction (two tests): drive a real ``langchain.agents.create_agent`` graph with the real ``DeferredToolFilterMiddleware`` through two model turns, including one that issues a re-entrant ``get_available_tools`` call to simulate the task_tool subagent path. - Real-LLM end-to-end (``test_deferred_tool_promotion_real_llm.py``, opt-in via ``ONEAPI_E2E=1``): drives the same flow against a real OpenAI-compatible model (verified on GPT-5.4-mini through the one-api gateway), watches the model call the promoted ``fake_calculator`` through the deferred-filter middleware, and asserts the right arithmetic result. Passes against the fixed branch. - Companion update to ``test_tool_deduplication.py``: dropped the ``@patch("deerflow.tools.tools.reset_deferred_registry")`` decorators because the symbol is no longer imported there. - Test fixtures in the new files patch ``deerflow.tools.tools.get_app_config`` with a minimal ``model_construct``-ed ``AppConfig`` instead of calling the real loader, so they never trigger ``_apply_singleton_configs`` and never leak ``_memory_config``/``_title_config``/… mutations into the rest of the suite. Full backend suite: 3208 passed / 14 skipped / 0 failed. ruff check + format clean. * fix(tools): address Copilot review on #2885 - tools.py: rewrite the reuse-path comment to spell out (a) why we don't reconcile the registry against the current ``mcp_tools`` snapshot — the MCP cache doesn't refresh mid-graph-run, the lead agent's ``ToolNode`` is already bound to the previous tool set anyway, and ``promote()`` drops the entry so a naive re-sync misclassifies promotions as new tools — and (b) why the log uses ``max(0, …)`` to avoid negative counts when the cache shrinks between snapshots. - Replace direct ``ts_mod._registry_var.set(None)`` in test fixtures with the public ``reset_deferred_registry()`` helper so tests don't couple to module internals. - Correct the docstring path in ``test_deferred_tool_registry_promotion.py`` to match the actual monkeypatch target (``deerflow.mcp.cache.get_cached_mcp_tools``). - Rename ``test_get_available_tools_resets_registry_wiping_promotion`` to ``test_get_available_tools_preserves_promotions_across_reentrant_calls`` so the test name describes the contract being asserted, not the bug it originally reproduced. Full backend suite: 3208 passed / 14 skipped. Real-LLM e2e: 1 passed. --- .../packages/harness/deerflow/tools/tools.py | 53 ++- .../test_deferred_tool_promotion_real_llm.py | 222 ++++++++++ .../test_deferred_tool_registry_promotion.py | 390 ++++++++++++++++++ backend/tests/test_tool_deduplication.py | 12 +- 4 files changed, 661 insertions(+), 16 deletions(-) create mode 100644 backend/tests/test_deferred_tool_promotion_real_llm.py create mode 100644 backend/tests/test_deferred_tool_registry_promotion.py diff --git a/backend/packages/harness/deerflow/tools/tools.py b/backend/packages/harness/deerflow/tools/tools.py index 01bfce43f..5c97962fc 100644 --- a/backend/packages/harness/deerflow/tools/tools.py +++ b/backend/packages/harness/deerflow/tools/tools.py @@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig from deerflow.reflection import resolve_variable from deerflow.sandbox.security import is_host_bash_allowed from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool -from deerflow.tools.builtins.tool_search import reset_deferred_registry +from deerflow.tools.builtins.tool_search import get_deferred_registry from deerflow.tools.sync import make_sync_tool_wrapper logger = logging.getLogger(__name__) @@ -116,8 +116,6 @@ def get_available_tools( # made through the Gateway API (which runs in a separate process) are immediately # reflected when loading MCP tools. mcp_tools = [] - # Reset deferred registry upfront to prevent stale state from previous calls - reset_deferred_registry() if include_mcp: try: from deerflow.config.extensions_config import ExtensionsConfig @@ -135,12 +133,51 @@ def get_available_tools( from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool - registry = DeferredToolRegistry() - for t in mcp_tools: - registry.register(t) - set_deferred_registry(registry) + # Reuse the existing registry if one is already set for + # this async context. ``get_available_tools`` is + # re-entered whenever a subagent is spawned + # (``task_tool`` calls it to build the child agent's + # toolset), and previously we used to unconditionally + # rebuild the registry — wiping out the parent agent's + # tool_search promotions. The + # ``DeferredToolFilterMiddleware`` then re-hid those + # tools from subsequent model calls, leaving the agent + # able to see a tool's name but unable to invoke it + # (issue #2884). ``contextvars`` already gives us the + # lifetime semantics we want: a fresh request / graph + # run starts in a new asyncio task with the + # ContextVar at its default of ``None``, so reuse is + # only triggered for re-entrant calls inside one run. + # + # Intentionally NOT reconciling against the current + # ``mcp_tools`` snapshot. The MCP cache only refreshes + # on ``extensions_config.json`` mtime changes, which + # in practice happens between graph runs — not inside + # one. And even if a refresh did happen mid-run, the + # already-built lead agent's ``ToolNode`` still holds + # the *previous* tool set (LangGraph binds tools at + # graph construction time), so a brand-new MCP tool + # couldn't actually be invoked anyway. The + # ``DeferredToolRegistry`` doesn't retain the names + # of previously-promoted tools (``promote()`` drops + # the entry entirely), so re-syncing the registry + # against a fresh ``mcp_tools`` list would + # mis-classify those promotions as new tools and + # re-register them as deferred — exactly the bug + # this fix exists to prevent. + existing_registry = get_deferred_registry() + if existing_registry is None: + registry = DeferredToolRegistry() + for t in mcp_tools: + registry.register(t) + set_deferred_registry(registry) + logger.info(f"Tool search active: {len(mcp_tools)} tools deferred") + else: + mcp_tool_names = {t.name for t in mcp_tools} + still_deferred = len(existing_registry) + promoted_count = max(0, len(mcp_tool_names) - still_deferred) + logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted") builtin_tools.append(tool_search_tool) - logger.info(f"Tool search active: {len(mcp_tools)} tools deferred") except ImportError: logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.") except Exception as e: diff --git a/backend/tests/test_deferred_tool_promotion_real_llm.py b/backend/tests/test_deferred_tool_promotion_real_llm.py new file mode 100644 index 000000000..46ae24d41 --- /dev/null +++ b/backend/tests/test_deferred_tool_promotion_real_llm.py @@ -0,0 +1,222 @@ +"""Real-LLM end-to-end verification for issue #2884. + +Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI- +compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware`` +and the production ``get_available_tools`` pipeline. The only thing we mock is +the MCP tool source — we hand-roll two ``@tool``s and inject them through +``deerflow.mcp.cache.get_cached_mcp_tools``. + +The flow exercised: + 1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger`` + that re-enters ``get_available_tools`` on the same task — this is the + code path issue #2884 reports). It must call ``tool_search`` to + discover the deferred ``fake_calculator`` tool. + 2. Tool batch: ``tool_search`` promotes ``fake_calculator``; + ``fake_subagent_trigger`` re-enters ``get_available_tools``. + 3. Turn 2: the promoted ``fake_calculator`` schema must reach the model + so it can actually call it. Without this PR's fix, the re-entry wipes + the promotion and the model can no longer invoke the tool. + +Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every +test run. Run with:: + + ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \ + PYTHONPATH=. uv run pytest \ + tests/test_deferred_tool_promotion_real_llm.py -v -s +""" + +from __future__ import annotations + +import os + +import pytest +from langchain_core.messages import HumanMessage +from langchain_core.tools import tool as as_tool + +# --------------------------------------------------------------------------- +# Skip control: only run when explicitly opted in. +# --------------------------------------------------------------------------- + + +pytestmark = pytest.mark.skipif( + os.getenv("ONEAPI_E2E") != "1", + reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)", +) + + +# --------------------------------------------------------------------------- +# Fake "MCP" tools the agent should discover via tool_search. +# Keep them obviously synthetic so the model can pattern-match the search. +# --------------------------------------------------------------------------- + + +_calls: list[str] = [] + + +@as_tool +def fake_calculator(expression: str) -> str: + """Evaluate a tiny arithmetic expression like '2 + 2'. + + Reserved for the user — only call this if the user asks for arithmetic. + """ + _calls.append(f"fake_calculator:{expression}") + try: + # Trivially safe-eval just for the e2e check + allowed = set("0123456789+-*/() .") + if not set(expression) <= allowed: + return "expression contains disallowed characters" + return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307 + except Exception as e: + return f"error: {e}" + + +@as_tool +def fake_translator(text: str, target_lang: str) -> str: + """Translate text into the given language code. Decorative — not used.""" + _calls.append(f"fake_translator:{text}:{target_lang}") + return f"[{target_lang}] {text}" + + +# --------------------------------------------------------------------------- +# Pipeline wiring (same shape as the in-process tests). +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_registry_between_tests(): + from deerflow.tools.builtins.tool_search import reset_deferred_registry + + reset_deferred_registry() + yield + reset_deferred_registry() + + +def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None: + from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig + + real_ext = ExtensionsConfig( + mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)}, + ) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: real_ext), + ) + monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools)) + + +def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None: + """Build a minimal mock AppConfig and patch the symbol — never call the + real loader, which would trigger ``_apply_singleton_configs`` and + permanently mutate cross-test singletons (memory, title, …).""" + from deerflow.config.app_config import AppConfig + from deerflow.config.tool_search_config import ToolSearchConfig + + mock_cfg = AppConfig.model_construct( + log_level="info", + models=[], + tools=[], + tool_groups=[], + sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"), + tool_search=ToolSearchConfig(enabled=True), + ) + monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg) + + +# --------------------------------------------------------------------------- +# Real-LLM e2e test +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch): + """End-to-end against a real OpenAI-compatible LLM. + + The model must: + Turn 1 — see ``tool_search`` (deferred tools aren't bound yet) and + batch-call BOTH ``tool_search(select:fake_calculator)`` AND + ``fake_subagent_trigger(...)``. + Turn 2 — call ``fake_calculator`` and finish. + + Pass criterion: ``fake_calculator`` actually gets invoked at the tool + layer — recorded in ``_calls`` — which proves the model received the + promoted schema after the re-entrant ``get_available_tools`` call. + """ + from langchain.agents import create_agent + from langchain_openai import ChatOpenAI + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator]) + _force_tool_search_enabled(monkeypatch) + _calls.clear() + + @as_tool + async def fake_subagent_trigger(prompt: str) -> str: + """Pretend to spawn a subagent. Internally rebuilds the toolset. + + Use this whenever the user asks you to delegate work — pass a short + description as ``prompt``. + """ + # ``task_tool`` does this internally. Whether the registry-reset that + # used to happen here actually leaks back to the parent task depends + # on asyncio's implicit context-copying semantics (gather creates + # child tasks with copied contexts, so reset_deferred_registry is + # task-local) — but the fix in this PR is what GUARANTEES the + # promotion sticks regardless of which integration path triggers a + # re-entrant ``get_available_tools`` call. + get_available_tools(subagent_enabled=False) + _calls.append(f"fake_subagent_trigger:{prompt}") + return "subagent completed" + + tools = get_available_tools() + [fake_subagent_trigger] + + model = ChatOpenAI( + model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"), + api_key=os.environ["OPENAI_API_KEY"], + base_url=os.environ["OPENAI_API_BASE"], + temperature=0, + max_retries=1, + ) + + system_prompt = ( + "You are a meticulous assistant. Available deferred tools include a " + "calculator and a translator — their schemas are hidden until you " + "search for them via tool_search.\n\n" + "Procedure for the user's request:\n" + " 1. Call tool_search with query 'select:fake_calculator' AND " + "in the SAME tool batch also call fake_subagent_trigger(prompt='go') " + "to delegate the side work. Put both tool_calls in your first response.\n" + " 2. After both tool messages come back, call fake_calculator with " + "the user's expression.\n" + " 3. Reply with just the numeric result." + ) + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt=system_prompt, + ) + + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]}, + config={"recursion_limit": 12}, + ) + + print("\n=== tool calls recorded ===") + for c in _calls: + print(f" {c}") + print("\n=== final message ===") + final_text = result["messages"][-1].content if result["messages"] else "(none)" + print(f" {final_text!r}") + + # The smoking-gun assertion: fake_calculator was actually invoked at the + # tool layer. This is only possible if the promoted schema reached the + # model in turn 2, despite the subagent-style re-entry in turn 1. + calc_calls = [c for c in _calls if c.startswith("fake_calculator:")] + assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}" + + # And the math should actually be done correctly (sanity that the LLM + # really used the result, not just hallucinated the answer). + assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}" diff --git a/backend/tests/test_deferred_tool_registry_promotion.py b/backend/tests/test_deferred_tool_registry_promotion.py new file mode 100644 index 000000000..23b7649ec --- /dev/null +++ b/backend/tests/test_deferred_tool_registry_promotion.py @@ -0,0 +1,390 @@ +"""Reproduce + regression-guard issue #2884. + +Hypothesis from the issue: + ``tools.tools.get_available_tools`` unconditionally calls + ``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry`` + every time it is invoked. If anything calls ``get_available_tools`` again + during the same async context (after the agent has promoted tools via + ``tool_search``), the promotion is wiped and the next model call hides the + tool's schema again. + +These tests pin two things: + +A. **At the unit boundary** — verify the failure mode directly. Promote a + tool in the registry, then call ``get_available_tools`` again and observe + that the ContextVar registry is reset and the promotion is lost. + +B. **At the graph-execution boundary** — drive a real ``create_agent`` graph + with the real ``DeferredToolFilterMiddleware`` through two model turns. + The first turn calls ``tool_search`` which promotes a tool. The second + turn must see that tool's schema in ``request.tools``. If + ``get_available_tools`` were to run again between the two turns and reset + the registry, the second turn's filter would strip the tool. + +Strategy: use the production ``deerflow.tools.tools.get_available_tools`` +unmodified; mock only the LLM and the MCP tool source. Patch +``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that +``get_available_tools`` resolves via lazy import) to return our fixture +tools so we don't need a real MCP server. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import tool as as_tool + + +class FakeToolCallingModel(FakeMessagesListChatModel): + """FakeMessagesListChatModel + no-op bind_tools so create_agent works.""" + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + return self + + +# --------------------------------------------------------------------------- +# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled +# --------------------------------------------------------------------------- + + +@as_tool +def fake_mcp_search(query: str) -> str: + """Pretend to search a knowledge base for the given query.""" + return f"results for {query}" + + +@as_tool +def fake_mcp_fetch(url: str) -> str: + """Pretend to fetch a page at the given URL.""" + return f"content of {url}" + + +@pytest.fixture(autouse=True) +def _supply_env(monkeypatch: pytest.MonkeyPatch): + """config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder.""" + monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used") + monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid") + + +@pytest.fixture(autouse=True) +def _reset_deferred_registry_between_tests(): + """Each test must start with a clean ContextVar. + + The registry lives in a module-level ContextVar with no per-task isolation + in a synchronous test runner, so one test's promotion can leak into the + next and silently break filter assertions. + """ + from deerflow.tools.builtins.tool_search import reset_deferred_registry + + reset_deferred_registry() + yield + reset_deferred_registry() + + +def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None: + """Make get_available_tools believe an MCP server is registered. + + Build a real ``ExtensionsConfig`` with one enabled MCP server entry so + that both ``AppConfig.from_file`` (which calls + ``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools`` + (which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``) + see a valid instance. Then point the MCP tool cache at our fixture tools. + """ + from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig + + real_ext = ExtensionsConfig( + mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)}, + ) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: real_ext), + ) + monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools)) + + +def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None: + """Force config.tool_search.enabled=True without touching the yaml. + + Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs`` + which permanently mutates module-level singletons (``_memory_config``, + ``_title_config``, …) to match the developer's ``config.yaml`` — even + after pytest restores our patch. That leaks across tests later in the + run that rely on those singletons' DEFAULTS (e.g. memory queue tests + require ``_memory_config.enabled = True``, which is the dataclass default + but FALSE in the actual yaml). + + Build a minimal mock AppConfig instead and never call the real loader. + """ + from deerflow.config.app_config import AppConfig + from deerflow.config.tool_search_config import ToolSearchConfig + + mock_cfg = AppConfig.model_construct( + log_level="info", + models=[], + tools=[], + tool_groups=[], + sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"), + tool_search=ToolSearchConfig(enabled=True), + ) + monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg) + + +# --------------------------------------------------------------------------- +# Section A — direct unit-level reproduction +# --------------------------------------------------------------------------- + + +def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch): + """Re-entrant ``get_available_tools()`` must preserve prior promotions. + + Step 1: call get_available_tools() — registers MCP tools as deferred. + Step 2: simulate the agent calling tool_search by promoting one tool. + Step 3: call get_available_tools() again (the same code path + ``task_tool`` exercises mid-run). + + Assertion: after step 3, the promoted tool is STILL promoted (not + re-deferred). On ``main`` before the fix, step 3's + ``reset_deferred_registry()`` wiped the promotion and re-registered + every MCP tool as deferred — this assertion fired with + ``REGRESSION (#2884)``. + """ + from deerflow.tools.builtins.tool_search import get_deferred_registry + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + # Step 1: first call — both MCP tools start deferred + get_available_tools() + reg1 = get_deferred_registry() + assert reg1 is not None + assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"} + + # Step 2: simulate tool_search promoting one of them + reg1.promote({"fake_mcp_search"}) + assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search" + + # Step 3: second call — registry must NOT silently undo the promotion + get_available_tools() + reg2 = get_deferred_registry() + assert reg2 is not None + deferred_after = {e.name for e in reg2.entries} + assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}" + + +# --------------------------------------------------------------------------- +# Section B — graph-execution reproduction +# --------------------------------------------------------------------------- + + +class _ToolSearchPromotingModel(FakeToolCallingModel): + """Two-turn model that: + + Turn 1 → emit a tool_call for ``tool_search`` (the real one) + Turn 2 → emit a tool_call for ``fake_mcp_search`` (the promoted tool) + + Records the tools it received on each turn so the test can inspect what + DeferredToolFilterMiddleware actually fed to ``bind_tools``. + """ + + bound_tools_per_turn: list[list[str]] = [] + + def bind_tools( # type: ignore[override] + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + # Record the tool names the model would see in this turn + names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools] + self.bound_tools_per_turn.append(names) + return self + + +def _build_promoting_model() -> _ToolSearchPromotingModel: + return _ToolSearchPromotingModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "tool_search", + "args": {"query": "select:fake_mcp_search"}, + "id": "call_search_1", + "type": "tool_call", + } + ], + ), + AIMessage( + content="", + tool_calls=[ + { + "name": "fake_mcp_search", + "args": {"query": "hello"}, + "id": "call_mcp_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="all done"), + ] + ) + + +def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch): + """End-to-end: drive a real create_agent graph through two turns. + + Without the fix, the second-turn bind_tools call should NOT contain + fake_mcp_search (because DeferredToolFilterMiddleware sees it in the + registry and strips it). With the fix, the model sees the schema and can + invoke it. + """ + from langchain.agents import create_agent + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + tools = get_available_tools() + # Sanity: the assembled tool list includes the deferred tools (they're in + # bind_tools but DeferredToolFilterMiddleware strips deferred ones before + # they reach the model) + tool_names = {getattr(t, "name", "") for t in tools} + assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names + + model = _build_promoting_model() + model.bound_tools_per_turn = [] # reset class-level recorder + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt="bug-2884-repro", + ) + + graph.invoke({"messages": [HumanMessage(content="use the search tool")]}) + + # Turn 1: model should NOT see fake_mcp_search (it's deferred) + turn1 = set(model.bound_tools_per_turn[0]) + assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}" + assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}" + + # Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it. + # This is the load-bearing assertion for issue #2884. + assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}" + turn2 = set(model.bound_tools_per_turn[1]) + assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}" + + +# --------------------------------------------------------------------------- +# Section C — the actual issue #2884 trigger: a re-entrant +# get_available_tools call (e.g. when task_tool spawns a subagent) must not +# wipe the parent's promotion. +# --------------------------------------------------------------------------- + + +def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch): + """Issue #2884 in its real shape: a re-entrant get_available_tools call + (the same pattern that happens when ``task_tool`` builds a subagent's + toolset mid-run) must not wipe the parent agent's tool_search promotions. + + Turn 1's tool batch contains BOTH ``tool_search`` (which promotes + ``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls + ``get_available_tools`` again — exactly what ``task_tool`` does when it + builds a subagent's toolset). With the fix, turn 2's bind_tools sees the + promoted tool. Without the fix, the re-entry wipes the registry and + the filter re-hides it. + """ + from langchain.agents import create_agent + + from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware + from deerflow.tools.tools import get_available_tools + + _patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch]) + _force_tool_search_enabled(monkeypatch) + + # The trigger tool simulates what task_tool does internally: rebuild the + # toolset by calling get_available_tools while the registry is live. + @as_tool + def fake_subagent_trigger(prompt: str) -> str: + """Pretend to spawn a subagent. Internally rebuilds the toolset.""" + get_available_tools(subagent_enabled=False) + return f"spawned subagent for: {prompt}" + + tools = get_available_tools() + [fake_subagent_trigger] + + bound_per_turn: list[list[str]] = [] + + class _Model(FakeToolCallingModel): + def bind_tools(self, tools_arg, **kwargs): # type: ignore[override] + bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg]) + return self + + model = _Model( + responses=[ + # Turn 1: do both in one batch — promote AND trigger the + # subagent-style rebuild. LangGraph executes them in order in the + # same agent step. + AIMessage( + content="", + tool_calls=[ + { + "name": "tool_search", + "args": {"query": "select:fake_mcp_search"}, + "id": "call_search_1", + "type": "tool_call", + }, + { + "name": "fake_subagent_trigger", + "args": {"prompt": "go"}, + "id": "call_trigger_1", + "type": "tool_call", + }, + ], + ), + # Turn 2: try to invoke the promoted tool. The model gets this + # turn only if turn 1's bind_tools recorded what the filter sent. + AIMessage( + content="", + tool_calls=[ + { + "name": "fake_mcp_search", + "args": {"query": "hello"}, + "id": "call_mcp_1", + "type": "tool_call", + } + ], + ), + AIMessage(content="all done"), + ] + ) + + graph = create_agent( + model=model, + tools=tools, + middleware=[DeferredToolFilterMiddleware()], + system_prompt="bug-2884-subagent-repro", + ) + graph.invoke({"messages": [HumanMessage(content="use the search tool")]}) + + # Turn 1 sanity: deferred tool not visible yet + assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0] + + # The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the + # re-entrant get_available_tools call that happened in turn 1's tool batch. + assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}" + turn2 = set(bound_per_turn[1]) + assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}" diff --git a/backend/tests/test_tool_deduplication.py b/backend/tests/test_tool_deduplication.py index ed9efffaf..f018fc57d 100644 --- a/backend/tests/test_tool_deduplication.py +++ b/backend/tests/test_tool_deduplication.py @@ -65,8 +65,7 @@ def _make_minimal_config(tools): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg): +def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): """Config-loaded async-only tools can still be invoked by sync clients.""" async def async_tool_impl(x: int) -> str: @@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg): +def test_no_duplicates_returned(mock_bash, mock_cfg): """get_available_tools() never returns two tools with the same name.""" mock_cfg.return_value = _make_minimal_config([]) @@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg): +def test_first_occurrence_wins(mock_bash, mock_cfg): """When duplicates exist, the first occurrence is kept.""" mock_cfg.return_value = _make_minimal_config([]) @@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg): @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) -@patch("deerflow.tools.tools.reset_deferred_registry") -def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog): +def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog): """A warning is logged for every skipped duplicate.""" import logging From eab7ae3d6283a51fbe759e761d39fce2308cc4a3 Mon Sep 17 00:00:00 2001 From: YuJitang Date: Wed, 13 May 2026 23:52:19 +0800 Subject: [PATCH 20/86] feat: stream subagent token usage to header via terminal task events (#2882) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: real-time subagent token usage display in header and per-turn Backend: - Persist subagent token usage to AIMessage.usage_metadata via TokenUsageMiddleware, so accumulateUsage() naturally includes subagent tokens without frontend state management - Cache subagent usage by tool_call_id in task_tool, write back to the dispatching AIMessage on next model response - Emit subagent token usage on all terminal task events (task_completed, task_failed, task_cancelled, task_timed_out) - Report subagent usage to parent RunJournal for API totals - Search backward from ToolMessage to find dispatching AIMessage for correct multi-tool-call attribution Frontend: - Remove subagentUsage state, custom event handling, and prop threading — subagent tokens are now embedded in message metadata - Simplify selectHeaderTokenUsage (no subagentUsage parameter) - Per-turn inline badges show turn-specific usage via message accumulation - Remove isLoading guard from MessageTokenUsageList for dynamic updates during streaming * fix: prevent header token double counting from baseline reset race onFinish, onError, and thread-switch useEffect all reset pendingUsageBaselineMessageIdsRef to an empty Set. If thread.isLoading is still true on the next render, all messages pass the getMessagesAfterBaseline filter and their tokens are added to backendUsage (which already includes them), causing the header to display up to 2× the actual token count. Capture current message IDs instead of using an empty Set so that getMessagesAfterBaseline correctly returns no pending messages even if thread.isLoading lags behind the stream end. * fix: write back subagent tokens for all concurrent task tool calls TokenUsageMiddleware only processed messages[-2], so when a single model response dispatched multiple task tool calls only the last ToolMessage had its cached subagent usage written back to the dispatch AIMessage.usage_metadata. Earlier tasks' usage stayed in _subagent_usage_cache indefinitely (leak) and never appeared in the per-turn inline token display. Walk backward through all consecutive ToolMessages before the new AIMessage, and accumulate updates targeting the same dispatch message into one state update so overlapping writes don't clobber each other. * fix: clean up subagent usage cache entry on task cancellation When a task_tool invocation is cancelled via CancelledError, any cached subagent usage entry leaked because the TokenUsageMiddleware writeback path never fires after cancellation. Pop the cache entry before re-raising to prevent unbounded growth of the module-level _subagent_usage_cache dict. * fix: address token usage review feedback * fix: handle missing config for subagent usage cache --------- Co-authored-by: Willem Jiang --- README.md | 2 +- backend/CLAUDE.md | 2 +- .../middlewares/token_usage_middleware.py | 61 ++++++- .../deerflow/tools/builtins/task_tool.py | 55 ++++++- .../tests/test_memory_queue_user_isolation.py | 5 +- backend/tests/test_task_tool_core_logic.py | 153 ++++++++++++++++++ backend/tests/test_token_usage_middleware.py | 49 +++++- .../messages/message-token-usage.tsx | 41 +++-- frontend/src/core/messages/usage.ts | 4 +- frontend/src/core/threads/hooks.ts | 18 ++- 10 files changed, 349 insertions(+), 41 deletions(-) diff --git a/README.md b/README.md index 9ff1d501b..8248e8fe4 100644 --- a/README.md +++ b/README.md @@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl Complex tasks rarely fit in a single pass. DeerFlow decomposes them. -The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. +The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step. This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands. diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 67ee9cc7e..5e0aebfdb 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har 8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting 9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled) 10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode) -11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional) +11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id 12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model 13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses) 14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support) diff --git a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py index f59e7f2b7..0d3607faf 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/token_usage_middleware.py @@ -9,7 +9,7 @@ from typing import Any, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware.todo import Todo -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from langgraph.runtime import Runtime logger = logging.getLogger(__name__) @@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str: return "thinking" +def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool: + """Return True if the AIMessage contains a tool_call with the given id.""" + for tc in message.tool_calls or []: + if isinstance(tc, dict): + if tc.get("id") == tool_call_id: + return True + elif hasattr(tc, "id") and tc.id == tool_call_id: + return True + return False + + def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]: tool_calls = getattr(message, "tool_calls", None) or [] actions: list[dict[str, Any]] = [] @@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware): if not messages: return None + # Annotate subagent token usage onto the AIMessage that dispatched it. + # When a task tool completes, its usage is cached by tool_call_id. Detect + # the ToolMessage → search backward for the corresponding AIMessage → merge. + # Walk backward through consecutive ToolMessages before the new AIMessage + # so that multiple concurrent task tool calls all get their subagent tokens + # written back to the same dispatch message (merging into one update). + state_updates: dict[int, AIMessage] = {} + if len(messages) >= 2: + from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage + + idx = len(messages) - 2 + while idx >= 0: + tool_msg = messages[idx] + if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id: + break + + subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id) + if subagent_usage: + # Search backward from the ToolMessage to find the AIMessage + # that dispatched it. A single model response can dispatch + # multiple task tool calls, so we can't assume a fixed offset. + dispatch_idx = idx - 1 + while dispatch_idx >= 0: + candidate = messages[dispatch_idx] + if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id): + # Accumulate into an existing update for the same + # AIMessage (multiple task calls in one response), + # or merge fresh from the original message. + existing_update = state_updates.get(dispatch_idx) + prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {}) + merged = { + **prev, + "input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"], + "output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"], + "total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"], + } + state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged}) + break + dispatch_idx -= 1 + idx -= 1 + last = messages[-1] if not isinstance(last, AIMessage): + if state_updates: + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} return None usage = getattr(last, "usage_metadata", None) @@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware): additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {}) if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution: - return None + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs}) - return {"messages": [updated_msg]} + state_updates[len(messages) - 1] = updated_msg + return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 861c45b45..cf9281ff4 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -26,6 +26,28 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can +# write it back to the triggering AIMessage's usage_metadata. +_subagent_usage_cache: dict[str, dict[str, int]] = {} + + +def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool: + if app_config is None: + try: + app_config = get_app_config() + except FileNotFoundError: + return False + return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False)) + + +def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None: + if enabled and usage: + _subagent_usage_cache[tool_call_id] = usage + + +def pop_cached_subagent_usage(tool_call_id: str) -> dict | None: + return _subagent_usage_cache.pop(tool_call_id, None) + def _is_subagent_terminal(result: Any) -> bool: """Return whether a background subagent result is safe to clean up.""" @@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None: return None +def _summarize_usage(records: list[dict] | None) -> dict | None: + """Summarize token usage records into a compact dict for SSE events.""" + if not records: + return None + return { + "input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records), + "output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records), + "total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records), + } + + def _report_subagent_usage(runtime: Any, result: Any) -> None: """Report subagent token usage to the parent RunJournal, if available. @@ -177,6 +210,7 @@ async def task_tool( subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. """ runtime_app_config = _get_runtime_app_config(runtime) + cache_token_usage = _token_usage_cache_enabled(runtime_app_config) available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names() # Get subagent configuration @@ -312,27 +346,32 @@ async def task_tool( last_message_count = current_message_count # Check if task completed, failed, or timed out + usage = _summarize_usage(getattr(result, "token_usage_records", None)) if result.status == SubagentStatus.COMPLETED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_completed", "task_id": task_id, "result": result.result}) + writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls") cleanup_background_task(task_id) return f"Task Succeeded. Result: {result.result}" elif result.status == SubagentStatus.FAILED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_failed", "task_id": task_id, "error": result.error}) + writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage}) logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}") cleanup_background_task(task_id) return f"Task failed. Error: {result.error}" elif result.status == SubagentStatus.CANCELLED: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_cancelled", "task_id": task_id, "error": result.error}) + writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage}) logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}") cleanup_background_task(task_id) return "Task cancelled by user." elif result.status == SubagentStatus.TIMED_OUT: + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id, "error": result.error}) + writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage}) logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}") cleanup_background_task(task_id) return f"Task timed out. Error: {result.error}" @@ -351,7 +390,9 @@ async def task_tool( timeout_minutes = config.timeout_seconds // 60 logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") _report_subagent_usage(runtime, result) - writer({"type": "task_timed_out", "task_id": task_id}) + usage = _summarize_usage(getattr(result, "token_usage_records", None)) + _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) + writer({"type": "task_timed_out", "task_id": task_id, "usage": usage}) return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" except asyncio.CancelledError: # Signal the background subagent thread to stop cooperatively. @@ -374,4 +415,8 @@ async def task_tool( cleanup_background_task(task_id) else: _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) + _subagent_usage_cache.pop(tool_call_id, None) + raise + except Exception: + _subagent_usage_cache.pop(tool_call_id, None) raise diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py index cf068e095..79250817c 100644 --- a/backend/tests/test_memory_queue_user_isolation.py +++ b/backend/tests/test_memory_queue_user_isolation.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue +from deerflow.config.memory_config import MemoryConfig def test_conversation_context_has_user_id(): @@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none(): def test_queue_add_stores_user_id(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") assert len(q._queue) == 1 assert q._queue[0].user_id == "alice" @@ -26,7 +27,7 @@ def test_queue_add_stores_user_id(): def test_queue_process_passes_user_id_to_updater(): q = MemoryUpdateQueue() - with patch.object(q, "_reset_timer"): + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): q.add(thread_id="t1", messages=["msg"], user_id="alice") mock_updater = MagicMock() diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 0591c0e8d..658968d65 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -59,12 +59,15 @@ def _make_result( ai_messages: list[dict] | None = None, result: str | None = None, error: str | None = None, + token_usage_records: list[dict] | None = None, ) -> SimpleNamespace: return SimpleNamespace( status=status, ai_messages=ai_messages or [], result=result, error=error, + token_usage_records=token_usage_records or [], + usage_reported=False, ) @@ -1132,3 +1135,153 @@ def test_cancellation_reports_subagent_usage(monkeypatch): assert len(report_calls) == 1 assert report_calls[0][1] is cancel_result assert cleanup_calls == ["tc-cancel-report"] + + +@pytest.mark.parametrize( + "status, expected_type", + [ + (FakeSubagentStatus.COMPLETED, "task_completed"), + (FakeSubagentStatus.FAILED, "task_failed"), + (FakeSubagentStatus.CANCELLED, "task_cancelled"), + (FakeSubagentStatus.TIMED_OUT, "task_timed_out"), + ], +) +def test_terminal_events_include_usage(monkeypatch, status, expected_type): + """Terminal task events include a usage summary from token_usage_records.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + records = [ + {"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}, + {"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280}, + ] + result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-usage", + ) + + terminal_events = [e for e in events if e["type"] == expected_type] + assert len(terminal_events) == 1 + assert terminal_events[0]["usage"] == { + "input_tokens": 300, + "output_tokens": 130, + "total_tokens": 430, + } + + +def test_terminal_event_usage_none_when_no_records(monkeypatch): + """Terminal event has usage=None when token_usage_records is empty.""" + config = _make_subagent_config() + runtime = _make_runtime() + events = [] + + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[]) + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-no-records", + ) + + completed = [e for e in events if e["type"] == "task_completed"] + assert len(completed) == 1 + assert completed[0]["usage"] is None + + +def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch): + monkeypatch.setattr( + task_tool_module, + "get_app_config", + MagicMock(side_effect=FileNotFoundError("missing config")), + ) + + assert task_tool_module._token_usage_cache_enabled(None) is False + + +def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False)) + runtime = _make_runtime(app_config=app_config) + records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}] + result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records) + + task_tool_module._subagent_usage_cache.clear() + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None) + monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-disabled-cache", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None + + +def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch): + config = _make_subagent_config() + app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True)) + runtime = _make_runtime(app_config=app_config) + + task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2} + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"]) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config) + monkeypatch.setattr( + task_tool_module, + "SubagentExecutor", + type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}), + ) + monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed"))) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None) + monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[])) + + with pytest.raises(RuntimeError, match="poll failed"): + _run_task_tool( + runtime=runtime, + description="test", + prompt="do work", + subagent_type="general-purpose", + tool_call_id="tc-error", + ) + + assert task_tool_module.pop_cached_subagent_usage("tc-error") is None diff --git a/backend/tests/test_token_usage_middleware.py b/backend/tests/test_token_usage_middleware.py index b24ff7b16..9686455c0 100644 --- a/backend/tests/test_token_usage_middleware.py +++ b/backend/tests/test_token_usage_middleware.py @@ -1,9 +1,10 @@ """Tests for TokenUsageMiddleware attribution annotations.""" +import importlib import logging from unittest.mock import MagicMock -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, ToolMessage from deerflow.agents.middlewares.token_usage_middleware import ( TOKEN_USAGE_ATTRIBUTION_KEY, @@ -232,3 +233,49 @@ class TestTokenUsageMiddleware: "tool_call_id": "write_todos:remove", } ] + + def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch): + middleware = TokenUsageMiddleware() + first_dispatch = AIMessage( + content="", + tool_calls=[{"id": "task:first", "name": "task", "args": {}}], + ) + second_dispatch = AIMessage( + content="", + tool_calls=[ + {"id": "task:second-a", "name": "task", "args": {}}, + {"id": "task:second-b", "name": "task", "args": {}}, + ], + ) + messages = [ + first_dispatch, + ToolMessage(content="first", tool_call_id="task:first"), + second_dispatch, + ToolMessage(content="second-a", tool_call_id="task:second-a"), + ToolMessage(content="second-b", tool_call_id="task:second-b"), + AIMessage(content="done"), + ] + cached_usage = { + "task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + "task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27}, + } + + task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool") + monkeypatch.setattr( + task_tool_module, + "pop_cached_subagent_usage", + lambda tool_call_id: cached_usage.pop(tool_call_id, None), + ) + + result = middleware.after_model({"messages": messages}, _make_runtime()) + + assert result is not None + usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)] + assert len(usage_updates) == 1 + updated = usage_updates[0] + assert updated.tool_calls == second_dispatch.tool_calls + assert updated.usage_metadata == { + "input_tokens": 30, + "output_tokens": 12, + "total_tokens": 42, + } diff --git a/frontend/src/components/workspace/messages/message-token-usage.tsx b/frontend/src/components/workspace/messages/message-token-usage.tsx index cc8d0debb..84f8a8057 100644 --- a/frontend/src/components/workspace/messages/message-token-usage.tsx +++ b/frontend/src/components/workspace/messages/message-token-usage.tsx @@ -12,13 +12,11 @@ function TokenUsageSummary({ inputTokens, outputTokens, totalTokens, - unavailable = false, }: { className?: string; inputTokens?: number; outputTokens?: number; totalTokens?: number; - unavailable?: boolean; }) { const { t } = useI18n(); @@ -33,21 +31,15 @@ function TokenUsageSummary({ {t.tokenUsage.label} - {!unavailable ? ( - <> - - {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} - - - {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} - - - {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} - - - ) : ( - {t.tokenUsage.unavailableShort} - )} + + {t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)} + + + {t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)} + + + {t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)} + ); } @@ -55,7 +47,7 @@ function TokenUsageSummary({ export function MessageTokenUsageList({ className, enabled = false, - isLoading = false, + isLoading: _isLoading = false, messages, }: { className?: string; @@ -63,7 +55,7 @@ export function MessageTokenUsageList({ isLoading?: boolean; messages: Message[]; }) { - if (!enabled || isLoading) { + if (!enabled) { return null; } @@ -75,13 +67,16 @@ export function MessageTokenUsageList({ const usage = accumulateUsage(aiMessages); + if (!usage) { + return null; + } + return ( ); } diff --git a/frontend/src/core/messages/usage.ts b/frontend/src/core/messages/usage.ts index 4679dffa5..01e3a59e1 100644 --- a/frontend/src/core/messages/usage.ts +++ b/frontend/src/core/messages/usage.ts @@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null { return hasUsage ? cumulative : null; } -function hasNonZeroUsage( +export function hasNonZeroUsage( usage: TokenUsage | null | undefined, ): usage is TokenUsage { return ( @@ -75,7 +75,7 @@ function hasNonZeroUsage( ); } -function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { +export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage { return { inputTokens: base.inputTokens + delta.inputTokens, outputTokens: base.outputTokens + delta.outputTokens, diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index 0ac790eb2..adf9dbbb6 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -296,7 +296,11 @@ export function useThreadStream({ onError(error) { setOptimisticMessages([]); toast.error(getStreamErrorMessage(error)); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ queryKey: threadTokenUsageQueryKey(threadIdRef.current), @@ -305,7 +309,11 @@ export function useThreadStream({ }, onFinish(state) { listeners.current.onFinish?.(state.values); - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); void queryClient.invalidateQueries({ queryKey: ["threads", "search"] }); if (threadIdRef.current && !isMock) { void queryClient.invalidateQueries({ @@ -339,7 +347,11 @@ export function useThreadStream({ useEffect(() => { startedRef.current = false; sendInFlightRef.current = false; - pendingUsageBaselineMessageIdsRef.current = new Set(); + pendingUsageBaselineMessageIdsRef.current = new Set( + messagesRef.current + .map(messageIdentity) + .filter((id): id is string => Boolean(id)), + ); prevHumanMsgCountRef.current = latestMessageCountsRef.current.humanMessageCount; }, [threadId]); From 6e8e6a969be803227aa71cb6ba4c5d116910b4b7 Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Wed, 13 May 2026 23:56:06 +0800 Subject: [PATCH 21/86] test: add blocking IO detector (#2924) * test: add blocking IO detector * test: add blocking IO probe option * test: harden blocking IO probe lifecycle * test: move blocking io detector to support --- backend/tests/conftest.py | 93 ++++++ backend/tests/support/__init__.py | 1 + backend/tests/support/detectors/__init__.py | 1 + .../tests/support/detectors/blocking_io.py | 287 ++++++++++++++++++ backend/tests/test_blocking_io_detector.py | 190 ++++++++++++ .../test_blocking_io_probe_integration.py | 22 ++ 6 files changed, 594 insertions(+) create mode 100644 backend/tests/support/__init__.py create mode 100644 backend/tests/support/detectors/__init__.py create mode 100644 backend/tests/support/detectors/blocking_io.py create mode 100644 backend/tests/test_blocking_io_detector.py create mode 100644 backend/tests/test_blocking_io_probe_integration.py diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index a357a3962..9bc8d4884 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import issues when unit-testing lightweight config/registry code in isolation. """ +from __future__ import annotations + import importlib.util import sys from pathlib import Path @@ -11,11 +13,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io # Make 'app' and 'deerflow' importable from any working directory sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts")) +_BACKEND_ROOT = Path(__file__).resolve().parents[1] +_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT) +_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector" + # Break the circular import chain that exists in production code: # deerflow.subagents.__init__ # -> .executor (SubagentExecutor, SubagentResult) @@ -56,6 +63,92 @@ def provisioner_module(): return module +@pytest.fixture() +def blocking_io_detector(): + """Fail a focused test if blocking calls run on the event loop thread.""" + with detect_blocking_io(fail_on_exit=True) as detector: + yield detector + + +def pytest_addoption(parser: pytest.Parser) -> None: + group = parser.getgroup("blocking-io") + group.addoption( + "--detect-blocking-io", + action="store_true", + default=False, + help="Collect blocking calls made while an asyncio event loop is running and report a summary.", + ) + group.addoption( + "--detect-blocking-io-fail", + action="store_true", + default=False, + help="Set a failing exit status when --detect-blocking-io records violations.", + ) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe") + + +def pytest_sessionstart(session: pytest.Session) -> None: + if _blocking_io_probe_enabled(session.config): + _blocking_io_probe.clear() + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_call(item: pytest.Item): + if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item): + yield + return + + detector = detect_blocking_io(fail_on_exit=False, stack_limit=18) + detector.__enter__() + setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector) + yield + + +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_teardown(item: pytest.Item): + yield + + detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None) + if detector is None: + return + + try: + detector.__exit__(None, None, None) + _blocking_io_probe.record(item.nodeid, detector.violations) + finally: + delattr(item, _BLOCKING_IO_DETECTOR_ATTR) + + +def pytest_sessionfinish(session: pytest.Session) -> None: + if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK: + session.exitstatus = pytest.ExitCode.TESTS_FAILED + + +def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None: + if not _blocking_io_probe_enabled(terminalreporter.config): + return + + header, *details = _blocking_io_probe.format_summary().splitlines() + terminalreporter.write_sep("=", header) + for line in details: + terminalreporter.write_line(line) + + +def _blocking_io_probe_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_fail_enabled(config: pytest.Config) -> bool: + return bool(config.getoption("--detect-blocking-io-fail")) + + +def _blocking_io_probe_skipped(item: pytest.Item) -> bool: + return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None + + # --------------------------------------------------------------------------- # Auto-set user context for every test unless marked no_auto_user # --------------------------------------------------------------------------- diff --git a/backend/tests/support/__init__.py b/backend/tests/support/__init__.py new file mode 100644 index 000000000..38361eaf5 --- /dev/null +++ b/backend/tests/support/__init__.py @@ -0,0 +1 @@ +"""Shared test support helpers.""" diff --git a/backend/tests/support/detectors/__init__.py b/backend/tests/support/detectors/__init__.py new file mode 100644 index 000000000..cf9568cb6 --- /dev/null +++ b/backend/tests/support/detectors/__init__.py @@ -0,0 +1 @@ +"""Runtime and static detectors used by tests.""" diff --git a/backend/tests/support/detectors/blocking_io.py b/backend/tests/support/detectors/blocking_io.py new file mode 100644 index 000000000..c1adfd55a --- /dev/null +++ b/backend/tests/support/detectors/blocking_io.py @@ -0,0 +1,287 @@ +"""Test helper for detecting blocking calls on an asyncio event loop. + +The detector is intentionally test-only. It monkeypatches a small set of +well-known blocking entry points and their already-loaded module-level aliases, +then records calls only when they happen on a thread that is currently running +an asyncio event loop. Aliases captured in closures or default arguments remain +out of scope. +""" + +from __future__ import annotations + +import asyncio +import importlib +import sys +import traceback +from collections import Counter +from collections.abc import Callable, Iterable, Iterator +from contextlib import AbstractContextManager +from dataclasses import dataclass +from functools import wraps +from pathlib import Path +from types import TracebackType +from typing import Any + +BlockingCallable = Callable[..., Any] + + +@dataclass(frozen=True) +class BlockingCallSpec: + """Describes one blocking callable to wrap during a detector run.""" + + name: str + target: str + record_on_iteration: bool = False + + +@dataclass(frozen=True) +class BlockingCall: + """One blocking call observed on an asyncio event loop thread.""" + + name: str + target: str + stack: tuple[traceback.FrameSummary, ...] + + +DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = ( + BlockingCallSpec("time.sleep", "time:sleep"), + BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"), + BlockingCallSpec("httpx.Client.request", "httpx:Client.request"), + BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True), + BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"), + BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"), + BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"), +) + + +def _is_event_loop_thread() -> bool: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return False + return loop.is_running() + + +def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]: + module_name, attr_path = target.split(":", maxsplit=1) + owner: object = importlib.import_module(module_name) + parts = attr_path.split(".") + for part in parts[:-1]: + owner = getattr(owner, part) + + attr_name = parts[-1] + original = getattr(owner, attr_name) + return owner, attr_name, original + + +def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]: + return tuple(frame for frame in stack if frame.filename != __file__) + + +class BlockingIODetector(AbstractContextManager["BlockingIODetector"]): + """Record blocking calls made from async runtime code. + + By default the detector reports violations but does not fail on context + exit. Tests can set ``fail_on_exit=True`` or call + ``assert_no_blocking_calls()`` explicitly. + """ + + def __init__( + self, + specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS, + *, + fail_on_exit: bool = False, + patch_loaded_aliases: bool = True, + stack_limit: int = 12, + ) -> None: + self._specs = tuple(specs) + self._fail_on_exit = fail_on_exit + self._patch_loaded_aliases_enabled = patch_loaded_aliases + self._stack_limit = stack_limit + self._patches: list[tuple[object, str, BlockingCallable]] = [] + self._patch_keys: set[tuple[int, str]] = set() + self.violations: list[BlockingCall] = [] + self._active = False + + def __enter__(self) -> BlockingIODetector: + try: + self._active = True + alias_replacements: dict[int, BlockingCallable] = {} + for spec in self._specs: + owner, attr_name, original = _resolve_target(spec.target) + wrapper = self._wrap(spec, original) + self._patch_attribute(owner, attr_name, original, wrapper) + alias_replacements[id(original)] = wrapper + + if self._patch_loaded_aliases_enabled: + self._patch_loaded_module_aliases(alias_replacements) + except Exception: + self._restore() + self._active = False + raise + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback_value: TracebackType | None, + ) -> bool | None: + self._restore() + self._active = False + if exc_type is None and self._fail_on_exit: + self.assert_no_blocking_calls() + return None + + def _restore(self) -> None: + for owner, attr_name, original in reversed(self._patches): + setattr(owner, attr_name, original) + self._patches.clear() + self._patch_keys.clear() + + def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None: + key = (id(owner), attr_name) + if key in self._patch_keys: + return + setattr(owner, attr_name, replacement) + self._patches.append((owner, attr_name, original)) + self._patch_keys.add(key) + + def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None: + for module in tuple(sys.modules.values()): + namespace = getattr(module, "__dict__", None) + if not isinstance(namespace, dict): + continue + + for attr_name, value in tuple(namespace.items()): + replacement = replacements_by_id.get(id(value)) + if replacement is not None: + self._patch_attribute(module, attr_name, value, replacement) + + def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable: + @wraps(original) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if spec.record_on_iteration: + result = original(*args, **kwargs) + return self._wrap_iteration(spec, result) + self._record_if_blocking(spec) + return original(*args, **kwargs) + + return wrapper + + def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]: + iterator = iter(iterable) + reported = False + + while True: + if not reported: + reported = self._record_if_blocking(spec) + try: + yield next(iterator) + except StopIteration: + return + + def _record_if_blocking(self, spec: BlockingCallSpec) -> bool: + if self._active and _is_event_loop_thread(): + stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit)) + self.violations.append(BlockingCall(spec.name, spec.target, stack)) + return True + return False + + def assert_no_blocking_calls(self) -> None: + if self.violations: + raise AssertionError(format_blocking_calls(self.violations)) + + +class BlockingIOProbe: + """Collect detector output across tests and format a compact summary.""" + + def __init__(self, project_root: Path) -> None: + self._project_root = project_root.resolve() + self._observed: list[tuple[str, BlockingCall]] = [] + + @property + def violation_count(self) -> int: + return len(self._observed) + + @property + def test_count(self) -> int: + return len({nodeid for nodeid, _violation in self._observed}) + + def clear(self) -> None: + self._observed.clear() + + def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None: + for violation in violations: + self._observed.append((nodeid, violation)) + + def format_summary(self, *, limit: int = 30) -> str: + if not self._observed: + return "blocking io probe: no violations" + + call_sites: Counter[tuple[str, str, int, str, str]] = Counter() + for _nodeid, violation in self._observed: + frame = self._local_call_site(violation.stack) + if frame is None: + call_sites[(violation.name, "", 0, "", "")] += 1 + continue + + call_sites[ + ( + violation.name, + self._relative(frame.filename), + frame.lineno, + frame.name, + (frame.line or "").strip(), + ) + ] += 1 + + lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"] + for (name, filename, lineno, function, line), count in call_sites.most_common(limit): + lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}") + return "\n".join(lines) + + def _relative(self, filename: str) -> str: + try: + return str(Path(filename).resolve().relative_to(self._project_root)) + except ValueError: + return filename + + def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None: + local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")] + if local_frames: + return local_frames[-1] + + test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename] + return test_frames[-1] if test_frames else None + + +def detect_blocking_io( + specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS, + *, + fail_on_exit: bool = False, + patch_loaded_aliases: bool = True, + stack_limit: int = 12, +) -> BlockingIODetector: + """Create a detector context manager for a focused test scope.""" + + return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit) + + +def format_blocking_calls(violations: Iterable[BlockingCall]) -> str: + """Format detector output with enough stack context to locate call sites.""" + + lines = ["Blocking calls were executed on an asyncio event loop thread:"] + for index, violation in enumerate(violations, start=1): + lines.append(f"{index}. {violation.name} ({violation.target})") + lines.extend(_format_stack(violation.stack)) + return "\n".join(lines) + + +def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]: + for frame in stack: + location = f"{frame.filename}:{frame.lineno}" + lines = [f" at {frame.name} ({location})"] + if frame.line: + lines.append(f" {frame.line.strip()}") + yield from lines diff --git a/backend/tests/test_blocking_io_detector.py b/backend/tests/test_blocking_io_detector.py new file mode 100644 index 000000000..af44d746d --- /dev/null +++ b/backend/tests/test_blocking_io_detector.py @@ -0,0 +1,190 @@ +from __future__ import annotations + +import asyncio +import os +import time +from os import walk as imported_walk +from pathlib import Path +from time import sleep as imported_sleep + +import httpx +import pytest +import requests +from support.detectors.blocking_io import ( + BlockingCallSpec, + BlockingIOProbe, + detect_blocking_io, +) + +pytestmark = pytest.mark.asyncio + + +TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),) +REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),) +HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),) +OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),) +PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),) + + +async def test_records_time_sleep_on_event_loop() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def test_records_already_imported_sleep_alias_on_event_loop() -> None: + original_alias = imported_sleep + + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + imported_sleep(0) + + assert imported_sleep is original_alias + assert [violation.name for violation in detector.violations] == ["time.sleep"] + + +async def test_can_disable_loaded_alias_patching() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector: + imported_sleep(0) + + assert detector.violations == [] + + +async def test_does_not_record_time_sleep_offloaded_to_thread() -> None: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + await asyncio.to_thread(time.sleep, 0) + + assert detector.violations == [] + + +async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None: + await asyncio.to_thread(time.sleep, 0) + + assert blocking_io_detector.violations == [] + + +async def test_does_not_record_sync_call_without_running_event_loop() -> None: + def call_sleep() -> list[str]: + with detect_blocking_io(TIME_SLEEP_ONLY) as detector: + time.sleep(0) + return [violation.name for violation in detector.violations] + + assert await asyncio.to_thread(call_sleep) == [] + + +async def test_fail_on_exit_includes_call_site() -> None: + with pytest.raises(AssertionError) as exc_info: + with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True): + time.sleep(0) + + message = str(exc_info.value) + assert "time.sleep" in message + assert "test_fail_on_exit_includes_call_site" in message + + +async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str: + return f"{method}:{url}" + + monkeypatch.setattr(requests.sessions.Session, "request", fake_request) + + with detect_blocking_io(REQUESTS_ONLY) as detector: + assert requests.get("https://example.invalid") == "get:https://example.invalid" + + assert [violation.name for violation in detector.violations] == ["requests.Session.request"] + + +async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None: + def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response: + return httpx.Response(200, request=httpx.Request(method, url)) + + monkeypatch.setattr(httpx.Client, "request", fake_request) + + with detect_blocking_io(HTTPX_ONLY) as detector: + with httpx.Client() as client: + response = client.get("https://example.invalid") + + assert response.status_code == 200 + assert [violation.name for violation in detector.violations] == ["httpx.Client.request"] + + +async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(os.walk(tmp_path)) + + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + original_alias = imported_walk + + with detect_blocking_io(OS_WALK_ONLY) as detector: + assert list(imported_walk(tmp_path)) + + assert imported_walk is original_alias + assert [violation.name for violation in detector.violations] == ["os.walk"] + + +async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None: + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + + assert list(walker) + assert detector.violations == [] + + +async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None: + (tmp_path / "nested").mkdir() + + with detect_blocking_io(OS_WALK_ONLY) as detector: + walker = os.walk(tmp_path) + assert await asyncio.to_thread(lambda: list(walker)) + + assert detector.violations == [] + + +async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None: + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector: + assert path.read_text(encoding="utf-8") == "content" + + assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"] + + +async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + summary = probe.format_summary() + + assert "blocking io probe: 1 violations across 1 tests" in summary + assert "pathlib.Path.read_text" in summary + + +async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None: + probe = BlockingIOProbe(Path(__file__).resolve().parents[1]) + + assert probe.format_summary() == "blocking io probe: no violations" + + path = tmp_path / "data.txt" + path.write_text("content", encoding="utf-8") + with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector: + assert path.read_text(encoding="utf-8") == "content" + + probe.record("tests/test_example.py::test_example", detector.violations) + assert probe.violation_count == 1 + + probe.clear() + + assert probe.violation_count == 0 + assert probe.format_summary() == "blocking io probe: no violations" diff --git a/backend/tests/test_blocking_io_probe_integration.py b/backend/tests/test_blocking_io_probe_integration.py new file mode 100644 index 000000000..af7a31b9d --- /dev/null +++ b/backend/tests/test_blocking_io_probe_integration.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import time + +import pytest + +ORIGINAL_SLEEP = time.sleep + + +def replacement_sleep(seconds: float) -> None: + return None + + +def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(time, "sleep", replacement_sleep) + assert time.sleep is replacement_sleep + + +@pytest.mark.no_blocking_io_probe +def test_probe_restores_original_after_monkeypatch_teardown() -> None: + assert time.sleep is ORIGINAL_SLEEP + assert getattr(time.sleep, "__wrapped__", None) is None From ba864112a3b5e9029d6fe3f46ecb0abb1582d118 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 11:02:58 +0800 Subject: [PATCH 22/86] chore(deps): bump langsmith from 0.7.36 to 0.8.0 in /backend (#2943) Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.7.36 to 0.8.0. - [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases) - [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.7.36...v0.8.0) --- updated-dependencies: - dependency-name: langsmith dependency-version: 0.8.0 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- backend/uv.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/uv.lock b/backend/uv.lock index e144fb07e..cd6bc8543 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -2005,7 +2005,7 @@ wheels = [ [[package]] name = "langsmith" -version = "0.7.36" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, @@ -2018,9 +2018,9 @@ dependencies = [ { name = "xxhash" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" }, + { url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" }, ] [package.optional-dependencies] From 722c690f4fc734c5057b6ed250e3ff6b93168313 Mon Sep 17 00:00:00 2001 From: LawranceLiao <32213920+kibabsquirrel@users.noreply.github.com> Date: Fri, 15 May 2026 10:26:35 +0800 Subject: [PATCH 23/86] fix(memory): isolate queued memory updates by agent (#2941) * fix(memory): isolate queued memory updates by agent * fix(memory): include user in queue identity * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Fix the lint error --------- Co-authored-by: Willem Jiang Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../harness/deerflow/agents/memory/queue.py | 14 +++- .../agents/memory/summarization_hook.py | 3 + backend/tests/test_memory_queue.py | 84 ++++++++++++++++++- .../tests/test_memory_queue_user_isolation.py | 39 +++++++++ .../tests/test_summarization_middleware.py | 27 +++++- 5 files changed, 163 insertions(+), 4 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/memory/queue.py b/backend/packages/harness/deerflow/agents/memory/queue.py index b2a147bce..129a28c66 100644 --- a/backend/packages/harness/deerflow/agents/memory/queue.py +++ b/backend/packages/harness/deerflow/agents/memory/queue.py @@ -40,6 +40,15 @@ class MemoryUpdateQueue: self._timer: threading.Timer | None = None self._processing = False + @staticmethod + def _queue_key( + thread_id: str, + user_id: str | None, + agent_name: str | None, + ) -> tuple[str, str | None, str | None]: + """Return the debounce identity for a memory update target.""" + return (thread_id, user_id, agent_name) + def add( self, thread_id: str, @@ -115,8 +124,9 @@ class MemoryUpdateQueue: correction_detected: bool, reinforcement_detected: bool, ) -> None: + queue_key = self._queue_key(thread_id, user_id, agent_name) existing_context = next( - (context for context in self._queue if context.thread_id == thread_id), + (context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key), None, ) merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False) @@ -130,7 +140,7 @@ class MemoryUpdateQueue: reinforcement_detected=merged_reinforcement_detected, ) - self._queue = [c for c in self._queue if c.thread_id != thread_id] + self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key] self._queue.append(context) def _reset_timer(self) -> None: diff --git a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py index dafa7d977..307548e0a 100644 --- a/backend/packages/harness/deerflow/agents/memory/summarization_hook.py +++ b/backend/packages/harness/deerflow/agents/memory/summarization_hook.py @@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_ from deerflow.agents.memory.queue import get_memory_queue from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent from deerflow.config.memory_config import get_memory_config +from deerflow.runtime.user_context import resolve_runtime_user_id def memory_flush_hook(event: SummarizationEvent) -> None: @@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None: correction_detected = detect_correction(filtered_messages) reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages) + user_id = resolve_runtime_user_id(event.runtime) queue = get_memory_queue() queue.add_nowait( thread_id=event.thread_id, messages=filtered_messages, agent_name=event.agent_name, + user_id=user_id, correction_detected=correction_detected, reinforcement_detected=reinforcement_detected, ) diff --git a/backend/tests/test_memory_queue.py b/backend/tests/test_memory_queue.py index 27808b0e8..3d62f0497 100644 --- a/backend/tests/test_memory_queue.py +++ b/backend/tests/test_memory_queue.py @@ -1,6 +1,6 @@ import threading import time -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue from deerflow.config.memory_config import MemoryConfig @@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None: assert elapsed < 0.1 assert finished.is_set() is False assert finished.wait(1.0) is True + + +def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a") + queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b") + + assert queue.pending_count == 2 + assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"] + + +def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add( + thread_id="thread-1", + messages=["first"], + agent_name="agent-a", + correction_detected=True, + ) + queue.add( + thread_id="thread-1", + messages=["second"], + agent_name="agent-a", + correction_detected=False, + ) + + assert queue.pending_count == 1 + assert queue._queue[0].agent_name == "agent-a" + assert queue._queue[0].messages == ["second"] + assert queue._queue[0].correction_detected is True + + +def test_process_queue_updates_different_agents_in_same_thread_separately() -> None: + queue = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)), + patch.object(queue, "_reset_timer"), + ): + queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a") + queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b") + + mock_updater = MagicMock() + mock_updater.update_memory.return_value = True + + with ( + patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater), + patch("deerflow.agents.memory.queue.time.sleep"), + ): + queue.flush() + + assert mock_updater.update_memory.call_count == 2 + mock_updater.update_memory.assert_has_calls( + [ + call( + messages=["agent-a"], + thread_id="thread-1", + agent_name="agent-a", + correction_detected=False, + reinforcement_detected=False, + user_id=None, + ), + call( + messages=["agent-b"], + thread_id="thread-1", + agent_name="agent-b", + correction_detected=False, + reinforcement_detected=False, + user_id=None, + ), + ] + ) diff --git a/backend/tests/test_memory_queue_user_isolation.py b/backend/tests/test_memory_queue_user_isolation.py index 79250817c..ce5d41210 100644 --- a/backend/tests/test_memory_queue_user_isolation.py +++ b/backend/tests/test_memory_queue_user_isolation.py @@ -38,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater(): mock_updater.update_memory.assert_called_once() call_kwargs = mock_updater.update_memory.call_args.kwargs assert call_kwargs["user_id"] == "alice" + + +def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent(): + q = MemoryUpdateQueue() + + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): + q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice") + q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob") + + assert q.pending_count == 2 + assert [context.user_id for context in q._queue] == ["alice", "bob"] + assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]] + + +def test_queue_still_coalesces_updates_for_same_user_thread_and_agent(): + q = MemoryUpdateQueue() + + with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"): + q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice") + q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice") + + assert q.pending_count == 1 + assert q._queue[0].messages == ["second"] + assert q._queue[0].user_id == "alice" + assert q._queue[0].agent_name == "researcher" + + +def test_add_nowait_keeps_different_users_separate(): + q = MemoryUpdateQueue() + + with ( + patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), + patch.object(q, "_schedule_timer"), + ): + q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice") + q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob") + + assert q.pending_count == 2 + assert [context.user_id for context in q._queue] == ["alice", "bob"] diff --git a/backend/tests/test_summarization_middleware.py b/backend/tests/test_summarization_middleware.py index cbd94e434..9cd4fc725 100644 --- a/backend/tests/test_summarization_middleware.py +++ b/backend/tests/test_summarization_middleware.py @@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage: ) -def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace: +def _runtime( + thread_id: str | None = "thread-1", + agent_name: str | None = None, + user_id: str | None = None, +) -> SimpleNamespace: context = {} if thread_id is not None: context["thread_id"] = thread_id if agent_name is not None: context["agent_name"] = agent_name + if user_id is not None: + context["user_id"] = user_id return SimpleNamespace(context=context) @@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon queue.add_nowait.assert_called_once() assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent" + + +def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None: + queue = MagicMock() + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True)) + monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue) + + memory_flush_hook( + SummarizationEvent( + messages_to_summarize=tuple(_messages()[:2]), + preserved_messages=(), + thread_id="main", + agent_name="researcher", + runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"), + ) + ) + + queue.add_nowait.assert_called_once() + assert queue.add_nowait.call_args.kwargs["user_id"] == "alice" From 45060a9ffcfbda8f0dc0427ae09f518e496f4f33 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 15 May 2026 04:32:09 +0200 Subject: [PATCH 24/86] fix(runtime): avoid postgres aggregate row lock (#2962) --- .../deerflow/runtime/events/store/db.py | 33 ++++++++++++++----- backend/tests/test_run_event_store.py | 33 +++++++++++++++++++ 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index 9374769f3..b7e54754f 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -11,7 +11,7 @@ import logging from datetime import UTC, datetime from typing import Any -from sqlalchemy import delete, func, select +from sqlalchemy import delete, func, select, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.models.run_event import RunEventRow @@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore): user = get_current_user() return str(user.id) if user is not None else None + @staticmethod + async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None: + """Return the current max seq while serializing writers per thread. + + PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate + results are not lockable rows. As a release-safe workaround, take a + transaction-level advisory lock keyed by thread_id before reading the + aggregate. Other dialects keep the existing row-locking statement. + """ + stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id) + bind = session.get_bind() + dialect_name = bind.dialect.name if bind is not None else "" + + if dialect_name == "postgresql": + await session.execute( + text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"), + {"thread_id": thread_id}, + ) + return await session.scalar(stmt) + + return await session.scalar(stmt.with_for_update()) + async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401 """Write a single event — low-frequency path only. @@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore): user_id = self._user_id_from_context() async with self._sf() as session: async with session.begin(): - # Use FOR UPDATE to serialize seq assignment within a thread. - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = (max_seq or 0) + 1 row = RunEventRow( thread_id=thread_id, @@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore): async with self._sf() as session: async with session.begin(): # Get max seq for the thread (assume all events in batch belong to same thread). - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. thread_id = events[0]["thread_id"] - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) + max_seq = await self._max_seq_for_thread(session, thread_id) seq = max_seq or 0 rows = [] for e in events: diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py index d2c78ccf0..17b796af7 100644 --- a/backend/tests/test_run_event_store.py +++ b/backend/tests/test_run_event_store.py @@ -268,6 +268,39 @@ class TestEdgeCases: class TestDbRunEventStore: """Tests for DbRunEventStore with temp SQLite.""" + @pytest.mark.anyio + async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self): + from sqlalchemy.dialects import postgresql + + from deerflow.runtime.events.store.db import DbRunEventStore + + class FakeSession: + def __init__(self): + self.dialect = postgresql.dialect() + self.execute_calls = [] + self.scalar_stmt = None + + def get_bind(self): + return self + + async def execute(self, stmt, params=None): + self.execute_calls.append((stmt, params)) + + async def scalar(self, stmt): + self.scalar_stmt = stmt + return 41 + + session = FakeSession() + + max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1") + + assert max_seq == 41 + assert session.execute_calls + assert session.execute_calls[0][1] == {"thread_id": "thread-1"} + assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0]) + compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect())) + assert "FOR UPDATE" not in compiled + @pytest.mark.anyio async def test_basic_crud(self, tmp_path): from deerflow.persistence.engine import close_engine, get_session_factory, init_engine From 181d836541069a9d22708cc5f12c1e51f60ffed1 Mon Sep 17 00:00:00 2001 From: LawranceLiao <32213920+kibabsquirrel@users.noreply.github.com> Date: Fri, 15 May 2026 22:09:04 +0800 Subject: [PATCH 25/86] fix(middleware): normalize tool result adjacency before model calls (#2939) * normalizing tool-call transcripts before invocation * test(middleware): cover tool result regrouping edge cases --- .../dangling_tool_call_middleware.py | 49 ++++++----- .../test_dangling_tool_call_middleware.py | 82 +++++++++++++++++++ 2 files changed, 109 insertions(+), 22 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py index 5bb54f3e5..000ca51a2 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py @@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): return "[Tool call was interrupted and did not return a result.]" def _build_patched_messages(self, messages: list) -> list | None: - """Return a new message list with patches inserted at the correct positions. + """Return messages with tool results grouped after their tool-call AIMessage. - For each AIMessage with dangling tool_calls (no corresponding ToolMessage), - a synthetic ToolMessage is inserted immediately after that AIMessage. - Returns None if no patches are needed. + This normalizes model-bound causal order before provider serialization while + preserving already-valid transcripts unchanged. """ - # Collect IDs of all existing ToolMessages - existing_tool_msg_ids: set[str] = set() + tool_messages_by_id: dict[str, ToolMessage] = {} for msg in messages: if isinstance(msg, ToolMessage): - existing_tool_msg_ids.add(msg.tool_call_id) + tool_messages_by_id.setdefault(msg.tool_call_id, msg) - # Check if any patching is needed - needs_patch = False + tool_call_ids: set[str] = set() for msg in messages: if getattr(msg, "type", None) != "ai": continue for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if tc_id and tc_id not in existing_tool_msg_ids: - needs_patch = True - break - if needs_patch: - break + if tc_id: + tool_call_ids.add(tc_id) - if not needs_patch: - return None - - # Build new list with patches inserted right after each dangling AIMessage patched: list = [] - patched_ids: set[str] = set() + consumed_tool_msg_ids: set[str] = set() patch_count = 0 for msg in messages: + if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: + continue + patched.append(msg) if getattr(msg, "type", None) != "ai": continue + for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids: + if not tc_id or tc_id in consumed_tool_msg_ids: + continue + + existing_tool_msg = tool_messages_by_id.get(tc_id) + if existing_tool_msg is not None: + patched.append(existing_tool_msg) + consumed_tool_msg_ids.add(tc_id) + else: patched.append( ToolMessage( content=self._synthetic_tool_message_content(tc), @@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): status="error", ) ) - patched_ids.add(tc_id) + consumed_tool_msg_ids.add(tc_id) patch_count += 1 - logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") + if patched == messages: + return None + + if patch_count: + logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls") return patched @override diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index b1d5c476a..f9f47369d 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -158,6 +158,88 @@ class TestBuildPatchedMessagesPatching: assert patched[1].name == "bash" assert patched[1].status == "error" + def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self): + middleware = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + HumanMessage(content="interruption"), + _tool_msg("call_1", "bash"), + ] + patched = middleware._build_patched_messages(msgs) + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert isinstance(patched[2], HumanMessage) + + def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]), + HumanMessage(content="interruption"), + _tool_msg("call_2", "read"), + _tool_msg("call_1", "bash"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert isinstance(patched[2], ToolMessage) + assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"] + assert isinstance(patched[3], HumanMessage) + + def test_valid_adjacent_tool_results_are_unchanged(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + _tool_msg("call_1", "bash"), + HumanMessage(content="next"), + ] + + assert mw._build_patched_messages(msgs) is None + + def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + HumanMessage(content="interruption"), + _ai_with_tool_calls([_tc("read", "call_2")]), + _tool_msg("call_1", "bash"), + _tool_msg("call_2", "read"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert isinstance(patched[2], HumanMessage) + assert isinstance(patched[3], AIMessage) + assert isinstance(patched[4], ToolMessage) + assert patched[4].tool_call_id == "call_2" + + def test_orphan_tool_message_is_preserved_during_grouping(self): + mw = DanglingToolCallMiddleware() + orphan = _tool_msg("orphan_call", "orphan") + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1")]), + orphan, + HumanMessage(content="interruption"), + _tool_msg("call_1", "bash"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "call_1" + assert orphan in patched + assert patched.count(orphan) == 1 + def test_invalid_tool_call_is_patched(self): mw = DanglingToolCallMiddleware() msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])] From 0c37509b3854f18de64df3d5b0b2b43bc09d9b6f Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 15 May 2026 16:12:37 +0200 Subject: [PATCH 26/86] fix(middleware): Prevent todo completion reminder IMMessage leak (#2907) * fix(middleware): Prevent todo completion reminder IMMessage leak (#2892) * make format * fix(middleware): Clear stale todo reminder counts (#2892) * add size guard for _completion_reminder_counts and add a integration test --- .../agents/middlewares/todo_middleware.py | 222 +++++++++- backend/tests/test_todo_middleware.py | 384 +++++++++++++++++- frontend/src/core/messages/utils.ts | 15 +- .../tests/unit/core/messages/utils.test.ts | 34 ++ 4 files changed, 608 insertions(+), 47 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py index b8cd10884..9215aefc5 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/todo_middleware.py @@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list. Additionally, this middleware prevents the agent from exiting the loop while there are still incomplete todo items. When the model produces a final response -(no tool calls) but todos are not yet complete, the middleware injects a reminder -and jumps back to the model node to force continued engagement. +(no tool calls) but todos are not yet complete, the middleware queues a reminder +for the next model request and jumps back to the model node to force continued +engagement. The completion reminder is injected via ``wrap_model_call`` instead +of being persisted into graph state as a normal user-visible message. """ from __future__ import annotations +import threading +from collections.abc import Awaitable, Callable from typing import Any, override from langchain.agents.middleware import TodoListMiddleware from langchain.agents.middleware.todo import PlanningState, Todo -from langchain.agents.middleware.types import hook_config +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config from langchain_core.messages import AIMessage, HumanMessage from langgraph.runtime import Runtime @@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str: return "\n".join(lines) +def _format_completion_reminder(todos: list[Todo]) -> str: + """Format a completion reminder for incomplete todo items.""" + incomplete = [t for t in todos if t.get("status") != "completed"] + incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete) + return ( + "\n" + "You have incomplete todo items that must be finished before giving your final response:\n\n" + f"{incomplete_text}\n\n" + "Please continue working on these tasks. Call `write_todos` to mark items as completed " + "as you finish them, and only respond when all items are done.\n" + "" + ) + + +_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"} + + +def _has_tool_call_intent_or_error(message: AIMessage) -> bool: + """Return True when an AIMessage is not a clean final answer. + + Todo completion reminders should only fire when the model has produced a + plain final response. Provider/tool parsing details have moved across + LangChain versions and integrations, so keep all tool-intent/error signals + behind this helper instead of checking one concrete field at the call site. + """ + if message.tool_calls: + return True + + if getattr(message, "invalid_tool_calls", None): + return True + + # Backward/provider compatibility: some integrations preserve raw or legacy + # tool-call intent in additional_kwargs even when structured tool_calls is + # empty. If this helper changes, update the matching sentinel test + # `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`; + # if that test fails after a LangChain upgrade, review this helper so new + # tool-call/error fields are not silently treated as clean final answers. + additional_kwargs = getattr(message, "additional_kwargs", {}) or {} + if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"): + return True + + response_metadata = getattr(message, "response_metadata", {}) or {} + return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS + + class TodoMiddleware(TodoListMiddleware): """Extends TodoListMiddleware with `write_todos` context-loss detection. @@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware): formatted = _format_todos(todos) reminder = HumanMessage( name="todo_reminder", + additional_kwargs={"hide_from_ui": True}, content=( "\n" "Your todo list from earlier is no longer visible in the current context window, " @@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware): # Maximum number of completion reminders before allowing the agent to exit. # This prevents infinite loops when the agent cannot make further progress. _MAX_COMPLETION_REMINDERS = 2 + # Hard cap for per-run reminder bookkeeping in long-lived middleware instances. + _MAX_COMPLETION_REMINDER_KEYS = 4096 + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self._lock = threading.Lock() + self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {} + self._completion_reminder_counts: dict[tuple[str, str], int] = {} + self._completion_reminder_touch_order: dict[tuple[str, str], int] = {} + self._completion_reminder_next_order = 0 + + @staticmethod + def _get_thread_id(runtime: Runtime) -> str: + context = getattr(runtime, "context", None) + thread_id = context.get("thread_id") if context else None + return str(thread_id) if thread_id else "default" + + @staticmethod + def _get_run_id(runtime: Runtime) -> str: + context = getattr(runtime, "context", None) + run_id = context.get("run_id") if context else None + return str(run_id) if run_id else "default" + + def _pending_key(self, runtime: Runtime) -> tuple[str, str]: + return self._get_thread_id(runtime), self._get_run_id(runtime) + + def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None: + self._completion_reminder_next_order += 1 + self._completion_reminder_touch_order[key] = self._completion_reminder_next_order + + def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]: + keys = set(self._pending_completion_reminders) + keys.update(self._completion_reminder_counts) + keys.update(self._completion_reminder_touch_order) + return keys + + def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None: + self._pending_completion_reminders.pop(key, None) + self._completion_reminder_counts.pop(key, None) + self._completion_reminder_touch_order.pop(key, None) + + def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None: + keys = self._completion_reminder_keys_locked() + overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS + if overflow <= 0: + return + + candidates = [key for key in keys if key != protected_key] + candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0)) + for key in candidates[:overflow]: + self._drop_completion_reminder_key_locked(key) + + def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None: + key = self._pending_key(runtime) + with self._lock: + self._pending_completion_reminders.setdefault(key, []).append(reminder) + self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1 + self._touch_completion_reminder_key_locked(key) + self._prune_completion_reminder_state_locked(protected_key=key) + + def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int: + key = self._pending_key(runtime) + with self._lock: + return self._completion_reminder_counts.get(key, 0) + + def _drain_completion_reminders(self, runtime: Runtime) -> list[str]: + key = self._pending_key(runtime) + with self._lock: + reminders = self._pending_completion_reminders.pop(key, []) + if reminders or key in self._completion_reminder_counts: + self._touch_completion_reminder_key_locked(key) + return reminders + + def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None: + thread_id, current_run_id = self._pending_key(runtime) + with self._lock: + for key in self._completion_reminder_keys_locked(): + if key[0] == thread_id and key[1] != current_run_id: + self._drop_completion_reminder_key_locked(key) + + def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None: + key = self._pending_key(runtime) + with self._lock: + self._drop_completion_reminder_key_locked(key) + + @override + def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_other_run_completion_reminders(runtime) + return None + + @override + async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_other_run_completion_reminders(runtime) + return None @hook_config(can_jump_to=["model"]) @override @@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware): if base_result is not None: return base_result - # 2. Only intervene when the agent wants to exit (no tool calls). + # 2. Only intervene when the agent wants to exit cleanly. Tool-call + # intent or tool-call parse errors should be handled by the tool path + # instead of being masked by todo reminders. messages = state.get("messages") or [] last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None) - if not last_ai or last_ai.tool_calls: + if not last_ai or _has_tool_call_intent_or_error(last_ai): return None # 3. Allow exit when all todos are completed or there are no todos. @@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware): return None # 4. Enforce a reminder cap to prevent infinite re-engagement loops. - if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS: + if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS: return None - # 5. Inject a reminder and force the agent back to the model. - incomplete = [t for t in todos if t.get("status") != "completed"] - incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete) - reminder = HumanMessage( - name="todo_completion_reminder", - content=( - "\n" - "You have incomplete todo items that must be finished before giving your final response:\n\n" - f"{incomplete_text}\n\n" - "Please continue working on these tasks. Call `write_todos` to mark items as completed " - "as you finish them, and only respond when all items are done.\n" - "" - ), - ) - return {"jump_to": "model", "messages": [reminder]} + # 5. Queue a reminder for the next model request and jump back. We must + # not persist this control prompt as a normal HumanMessage, otherwise it + # can leak into user-visible message streams and saved transcripts. + self._queue_completion_reminder(runtime, _format_completion_reminder(todos)) + return {"jump_to": "model"} @override @hook_config(can_jump_to=["model"]) @@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware): ) -> dict[str, Any] | None: """Async version of after_model.""" return self.after_model(state, runtime) + + @staticmethod + def _format_pending_completion_reminders(reminders: list[str]) -> str: + return "\n\n".join(dict.fromkeys(reminders)) + + def _augment_request(self, request: ModelRequest) -> ModelRequest: + reminders = self._drain_completion_reminders(request.runtime) + if not reminders: + return request + new_messages = [ + *request.messages, + HumanMessage( + content=self._format_pending_completion_reminders(reminders), + name="todo_completion_reminder", + additional_kwargs={"hide_from_ui": True}, + ), + ] + return request.override(messages=new_messages) + + @override + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + return handler(self._augment_request(request)) + + @override + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + return await handler(self._augment_request(request)) + + @override + def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_current_run_completion_reminders(runtime) + return None + + @override + async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None: + self._clear_current_run_completion_reminders(runtime) + return None diff --git a/backend/tests/test_todo_middleware.py b/backend/tests/test_todo_middleware.py index efeee9eb0..934e730f2 100644 --- a/backend/tests/test_todo_middleware.py +++ b/backend/tests/test_todo_middleware.py @@ -1,14 +1,19 @@ """Tests for TodoMiddleware context-loss detection.""" import asyncio -from unittest.mock import MagicMock +from typing import Any +from unittest.mock import AsyncMock, MagicMock +from langchain.agents import create_agent +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel from langchain_core.messages import AIMessage, HumanMessage +from pydantic import PrivateAttr from deerflow.agents.middlewares.todo_middleware import ( TodoMiddleware, _completion_reminder_count, _format_todos, + _has_tool_call_intent_or_error, _reminder_in_messages, _todos_in_messages, ) @@ -22,9 +27,35 @@ def _reminder_msg(): return HumanMessage(name="todo_reminder", content="reminder") +class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel): + _seen_messages: list[list[Any]] = PrivateAttr(default_factory=list) + + @property + def seen_messages(self) -> list[list[Any]]: + return self._seen_messages + + def bind_tools(self, tools, *, tool_choice=None, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self._seen_messages.append(list(messages)) + return super()._generate( + messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ) + + def _make_runtime(): runtime = MagicMock() - runtime.context = {"thread_id": "test-thread"} + runtime.context = {"thread_id": "test-thread", "run_id": "test-run"} + return runtime + + +def _make_runtime_for(thread_id: str, run_id: str): + runtime = _make_runtime() + runtime.context = {"thread_id": thread_id, "run_id": run_id} return runtime @@ -161,10 +192,62 @@ def _completion_reminder_msg(): return HumanMessage(name="todo_completion_reminder", content="finish your todos") +def _todo_completion_reminders(messages): + reminders = [] + for message in messages: + if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder": + reminders.append(message) + return reminders + + def _ai_no_tool_calls(): return AIMessage(content="I'm done!") +def _ai_with_invalid_tool_calls(): + return AIMessage( + content="", + tool_calls=[], + invalid_tool_calls=[ + { + "type": "invalid_tool_call", + "id": "write_file:36", + "name": "write_file", + "args": "{invalid", + "error": "Failed to parse tool arguments", + } + ], + ) + + +def _ai_with_raw_provider_tool_calls(): + return AIMessage( + content="", + tool_calls=[], + invalid_tool_calls=[], + additional_kwargs={ + "tool_calls": [ + { + "id": "raw-tool-call", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path":"report.md"}'}, + } + ] + }, + ) + + +def _ai_with_legacy_function_call(): + return AIMessage( + content="", + additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}}, + ) + + +def _ai_with_tool_finish_reason(): + return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"}) + + def _incomplete_todos(): return [ {"status": "completed", "content": "Step 1"}, @@ -194,6 +277,36 @@ class TestCompletionReminderCount: assert _completion_reminder_count(msgs) == 1 +class TestToolCallIntentOrError: + def test_false_for_plain_final_answer(self): + assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False + + def test_true_for_structured_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True + + def test_true_for_invalid_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True + + def test_true_for_raw_provider_tool_calls(self): + assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True + + def test_true_for_legacy_function_call(self): + assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True + + def test_true_for_tool_finish_reason(self): + assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True + + def test_langchain_ai_message_tool_fields_are_explicitly_handled(self): + # Sentinel for LangChain compatibility: if future AIMessage versions add + # new top-level tool/function-call fields, this test should fail. When + # it does, update `_has_tool_call_intent_or_error()` so the completion + # reminder guard explicitly decides whether each new field means "not a + # clean final answer"; the helper has a matching comment pointing back + # to this sentinel. + tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())} + assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"} + + class TestAfterModel: def test_returns_none_when_agent_still_using_tools(self): mw = TodoMiddleware() @@ -235,68 +348,299 @@ class TestAfterModel: } assert mw.after_model(state, _make_runtime()) is None - def test_injects_reminder_and_jumps_to_model_when_incomplete(self): + def test_queues_reminder_and_jumps_to_model_when_incomplete(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [HumanMessage(content="hi"), _ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) + result = mw.after_model(state, runtime) assert result is not None assert result["jump_to"] == "model" - assert len(result["messages"]) == 1 - reminder = result["messages"][0] + assert "messages" not in result + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + request.override.assert_called_once() + reminder = request.override.call_args.kwargs["messages"][-1] assert isinstance(reminder, HumanMessage) assert reminder.name == "todo_completion_reminder" + assert reminder.additional_kwargs["hide_from_ui"] is True assert "Step 2" in reminder.content assert "Step 3" in reminder.content + handler.assert_called_once_with("patched-request") def test_reminder_lists_only_incomplete_items(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [_ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) - content = result["messages"][0].content + result = mw.after_model(state, runtime) + assert result is not None + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + mw.wrap_model_call(request, MagicMock(return_value="response")) + content = request.override.call_args.kwargs["messages"][-1].content assert "Step 1" not in content # completed — should not appear assert "Step 2" in content assert "Step 3" in content def test_allows_exit_after_max_reminders(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [ - _completion_reminder_msg(), - _completion_reminder_msg(), _ai_no_tool_calls(), ], "todos": _incomplete_todos(), } + assert mw.after_model(state, runtime) is not None + assert mw.after_model(state, runtime) is not None + assert mw.after_model(state, runtime) is None + + def test_still_sends_reminder_before_cap(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [ + _ai_no_tool_calls(), + ], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, runtime) is not None + result = mw.after_model(state, runtime) + assert result is not None + assert result["jump_to"] == "model" + + def test_does_not_trigger_for_invalid_tool_calls(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_invalid_tool_calls()], + "todos": _incomplete_todos(), + } assert mw.after_model(state, _make_runtime()) is None - def test_still_sends_reminder_before_cap(self): + def test_does_not_trigger_for_raw_provider_tool_calls(self): mw = TodoMiddleware() state = { - "messages": [ - _completion_reminder_msg(), # 1 reminder so far - _ai_no_tool_calls(), - ], + "messages": [_ai_with_raw_provider_tool_calls()], "todos": _incomplete_todos(), } - result = mw.after_model(state, _make_runtime()) - assert result is not None - assert result["jump_to"] == "model" + assert mw.after_model(state, _make_runtime()) is None + + def test_does_not_trigger_for_legacy_function_call(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_legacy_function_call()], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, _make_runtime()) is None + + def test_does_not_trigger_for_tool_finish_reason(self): + mw = TodoMiddleware() + state = { + "messages": [_ai_with_tool_finish_reason()], + "todos": _incomplete_todos(), + } + assert mw.after_model(state, _make_runtime()) is None class TestAafterModel: def test_delegates_to_sync(self): mw = TodoMiddleware() + runtime = _make_runtime() state = { "messages": [_ai_no_tool_calls()], "todos": _incomplete_todos(), } - result = asyncio.run(mw.aafter_model(state, _make_runtime())) + result = asyncio.run(mw.aafter_model(state, runtime)) assert result is not None assert result["jump_to"] == "model" - assert result["messages"][0].name == "todo_completion_reminder" + assert "messages" not in result + + +class TestWrapModelCall: + def test_no_pending_reminder_passthrough(self): + mw = TodoMiddleware() + request = MagicMock() + request.runtime = _make_runtime() + request.messages = [HumanMessage(content="hi")] + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + request.override.assert_not_called() + handler.assert_called_once_with(request) + + def test_pending_reminder_is_injected_once(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [_ai_no_tool_calls()], + "todos": _incomplete_todos(), + } + mw.after_model(state, runtime) + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = MagicMock(return_value="response") + + assert mw.wrap_model_call(request, handler) == "response" + injected_messages = request.override.call_args.kwargs["messages"] + assert injected_messages[-1].name == "todo_completion_reminder" + + request.override.reset_mock() + handler.reset_mock() + handler.return_value = "second-response" + assert mw.wrap_model_call(request, handler) == "second-response" + request.override.assert_not_called() + handler.assert_called_once_with(request) + + +class TestTodoMiddlewareAgentGraphIntegration: + def test_completion_reminder_is_transient_in_real_agent_graph(self): + mw = TodoMiddleware() + model = _CapturingFakeMessagesListChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[ + { + "name": "write_todos", + "id": "todos-1", + "args": { + "todos": [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "pending"}, + ] + }, + } + ], + ), + AIMessage(content="premature final 1"), + AIMessage(content="premature final 2"), + AIMessage(content="premature final 3"), + ], + ) + graph = create_agent(model=model, tools=[], middleware=[mw]) + + result = graph.invoke( + {"messages": [("user", "finish all todos")]}, + context={"thread_id": "integration-thread", "run_id": "integration-run"}, + ) + + assert len(model.seen_messages) == 4 + reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages] + assert reminders_by_call[0] == [] + assert reminders_by_call[1] == [] + assert len(reminders_by_call[2]) == 1 + assert len(reminders_by_call[3]) == 1 + assert "Step 1" not in reminders_by_call[2][0].content + assert "Step 2" in reminders_by_call[2][0].content + + persisted_reminders = _todo_completion_reminders(result["messages"]) + assert persisted_reminders == [] + assert result["messages"][-1].content == "premature final 3" + assert result["todos"] == [ + {"content": "Step 1", "status": "completed"}, + {"content": "Step 2", "status": "pending"}, + ] + assert mw._pending_completion_reminders == {} + assert mw._completion_reminder_counts == {} + + +class TestRunScopedReminderCleanup: + def test_before_agent_clears_stale_count_without_pending_reminder(self): + mw = TodoMiddleware() + stale_runtime = _make_runtime() + stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"} + current_runtime = _make_runtime() + current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"} + other_thread_runtime = _make_runtime() + other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"} + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, stale_runtime) is not None + assert mw.after_model(state, other_thread_runtime) is not None + + # Simulate a model call that drained the pending message, followed by an + # abnormal run end where after_agent did not clear the reminder count. + assert mw._drain_completion_reminders(stale_runtime) + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1 + + mw.before_agent({}, current_runtime) + + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1 + + def test_size_guard_prunes_oldest_count_only_reminder_state(self): + mw = TodoMiddleware() + mw._MAX_COMPLETION_REMINDER_KEYS = 2 + first_runtime = _make_runtime_for("thread-a", "run-a") + second_runtime = _make_runtime_for("thread-b", "run-b") + third_runtime = _make_runtime_for("thread-c", "run-c") + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, first_runtime) is not None + + # Simulate the normal model request path: pending reminder is consumed, + # but the run count remains until after_agent() or stale cleanup. + assert mw._drain_completion_reminders(first_runtime) + assert mw._completion_reminder_count_for_runtime(first_runtime) == 1 + + assert mw.after_model(state, second_runtime) is not None + assert mw.after_model(state, third_runtime) is not None + + assert mw._completion_reminder_count_for_runtime(first_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(second_runtime) == 1 + assert mw._completion_reminder_count_for_runtime(third_runtime) == 1 + assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order + + def test_size_guard_prunes_pending_and_count_state_together(self): + mw = TodoMiddleware() + mw._MAX_COMPLETION_REMINDER_KEYS = 1 + stale_runtime = _make_runtime_for("thread-a", "run-a") + current_runtime = _make_runtime_for("thread-b", "run-b") + + state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()} + assert mw.after_model(state, stale_runtime) is not None + assert mw.after_model(state, current_runtime) is not None + + assert mw._drain_completion_reminders(stale_runtime) == [] + assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0 + assert mw._completion_reminder_count_for_runtime(current_runtime) == 1 + + +class TestAwrapModelCall: + def test_async_pending_reminder_is_injected(self): + mw = TodoMiddleware() + runtime = _make_runtime() + state = { + "messages": [_ai_no_tool_calls()], + "todos": _incomplete_todos(), + } + mw.after_model(state, runtime) + + request = MagicMock() + request.runtime = runtime + request.messages = state["messages"] + request.override.return_value = "patched-request" + handler = AsyncMock(return_value="response") + + result = asyncio.run(mw.awrap_model_call(request, handler)) + assert result == "response" + injected_messages = request.override.call_args.kwargs["messages"] + assert injected_messages[-1].name == "todo_completion_reminder" + handler.assert_awaited_once_with("patched-request") diff --git a/frontend/src/core/messages/utils.ts b/frontend/src/core/messages/utils.ts index e20daa1b6..3f1fef9ad 100644 --- a/frontend/src/core/messages/utils.ts +++ b/frontend/src/core/messages/utils.ts @@ -26,6 +26,13 @@ export type MessageGroup = | AssistantClarificationGroup | AssistantSubagentGroup; +const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([ + "summary", + "loop_warning", + "todo_reminder", + "todo_completion_reminder", +]); + export function getMessageGroups(messages: Message[]): MessageGroup[] { if (messages.length === 0) { return []; @@ -53,10 +60,6 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] { continue; } - if (message.name === "todo_reminder") { - continue; - } - if (message.type === "human") { groups.push({ id: message.id, type: "human", messages: [message] }); continue; @@ -368,8 +371,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) { export function isHiddenFromUIMessage(message: Message) { return ( message.additional_kwargs?.hide_from_ui === true || - message.name === "summary" || - message.name === "loop_warning" + (typeof message.name === "string" && + HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name)) ); } diff --git a/frontend/tests/unit/core/messages/utils.test.ts b/frontend/tests/unit/core/messages/utils.test.ts index 24d014c7e..cbc245583 100644 --- a/frontend/tests/unit/core/messages/utils.test.ts +++ b/frontend/tests/unit/core/messages/utils.test.ts @@ -63,3 +63,37 @@ test("aggregates token usage messages once per assistant turn", () => { ), ).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]); }); + +test("hides internal todo reminder messages from message groups", () => { + const messages = [ + { + id: "human-1", + type: "human", + content: "Audit the middleware", + }, + { + id: "todo-reminder-1", + type: "human", + name: "todo_completion_reminder", + content: "finish todos", + }, + { + id: "todo-reminder-2", + type: "human", + name: "todo_reminder", + content: "remember todos", + }, + { + id: "ai-1", + type: "ai", + content: "Done", + }, + ] as Message[]; + + const groups = getMessageGroups(messages); + + expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]); + expect( + groups.flatMap((group) => group.messages).map((message) => message.id), + ).toEqual(["human-1", "ai-1"]); +}); From 7a2670eaea01501a3d54c94e980bd513a8c42efb Mon Sep 17 00:00:00 2001 From: Hinotobi Date: Fri, 15 May 2026 22:15:58 +0800 Subject: [PATCH 27/86] fix(gateway): cap skill artifact preview size (#2963) --- backend/app/gateway/routers/artifacts.py | 29 ++++++++++++++++++++---- backend/tests/test_artifacts_router.py | 15 ++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/backend/app/gateway/routers/artifacts.py b/backend/app/gateway/routers/artifacts.py index 78ea5fa00..a2cc5b02b 100644 --- a/backend/app/gateway/routers/artifacts.py +++ b/backend/app/gateway/routers/artifacts.py @@ -20,6 +20,9 @@ ACTIVE_CONTENT_MIME_TYPES = { "image/svg+xml", } +MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024 +_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024 + def _build_content_disposition(disposition_type: str, filename: str) -> str: """Build an RFC 5987 encoded Content-Disposition header value.""" @@ -44,6 +47,22 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool: return False +def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes: + """Read a .skill archive member while enforcing an uncompressed size cap.""" + if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES: + raise HTTPException(status_code=413, detail="Skill archive member is too large to preview") + + chunks: list[bytes] = [] + total_read = 0 + with zip_ref.open(info, "r") as src: + while chunk := src.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE): + total_read += len(chunk) + if total_read > MAX_SKILL_ARCHIVE_MEMBER_BYTES: + raise HTTPException(status_code=413, detail="Skill archive member is too large to preview") + chunks.append(chunk) + return b"".join(chunks) + + def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None: """Extract a file from a .skill ZIP archive. @@ -60,16 +79,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte try: with zipfile.ZipFile(zip_path, "r") as zip_ref: # List all files in the archive - namelist = zip_ref.namelist() + infos_by_name = {info.filename: info for info in zip_ref.infolist()} # Try direct path first - if internal_path in namelist: - return zip_ref.read(internal_path) + if internal_path in infos_by_name: + return _read_skill_archive_member(zip_ref, infos_by_name[internal_path]) # Try with any top-level directory prefix (e.g., "skill-name/SKILL.md") - for name in namelist: + for name, info in infos_by_name.items(): if name.endswith("/" + internal_path) or name == internal_path: - return zip_ref.read(name) + return _read_skill_archive_member(zip_ref, info) # Not found return None diff --git a/backend/tests/test_artifacts_router.py b/backend/tests/test_artifacts_router.py index df32e45dc..f0627ff7b 100644 --- a/backend/tests/test_artifacts_router.py +++ b/backend/tests/test_artifacts_router.py @@ -4,6 +4,7 @@ from pathlib import Path import pytest from _router_auth_helpers import call_unwrapped, make_authed_test_app +from fastapi import HTTPException from fastapi.testclient import TestClient from starlette.requests import Request from starlette.responses import FileResponse @@ -102,3 +103,17 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path assert response.status_code == 200 assert response.text == "hello" assert response.headers.get("content-disposition", "").startswith("attachment;") + + +def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None: + skill_path = tmp_path / "sample.skill" + payload = b"A" * (artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + 1) + with zipfile.ZipFile(skill_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zip_ref: + zip_ref.writestr("SKILL.md", payload) + + assert skill_path.stat().st_size < artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + + with pytest.raises(HTTPException) as exc_info: + artifacts_router._extract_file_from_skill_archive(skill_path, "SKILL.md") + + assert exc_info.value.status_code == 413 From 7c42ab3e1670235de2f317521f0b13ff16b0a104 Mon Sep 17 00:00:00 2001 From: Admire <64821731+LittleChenLiya@users.noreply.github.com> Date: Fri, 15 May 2026 22:27:10 +0800 Subject: [PATCH 28/86] fix(frontend): wait for async chat submit before clearing (#2940) * fix(frontend): wait for async chat submit before clearing * test(frontend): cover pending attachment uploads * fix(frontend): preserve sync submit semantics --- .../[agent_name]/chats/[thread_id]/page.tsx | 12 +++- .../app/workspace/chats/[thread_id]/page.tsx | 6 +- .../components/ai-elements/prompt-input.tsx | 32 +++++++--- .../src/components/workspace/input-box.tsx | 25 ++++++-- frontend/tests/e2e/chat.spec.ts | 62 +++++++++++++++++++ 5 files changed, 120 insertions(+), 17 deletions(-) diff --git a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx index 8627762b0..c16af882a 100644 --- a/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/agents/[agent_name]/chats/[thread_id]/page.tsx @@ -66,6 +66,7 @@ export default function AgentChatPage() { thread, pendingUsageMessages, sendMessage, + isUploading, isHistoryLoading, hasMoreHistory, loadMoreHistory, @@ -106,7 +107,11 @@ export default function AgentChatPage() { const handleSubmit = useCallback( (message: PromptInputMessage) => { - void sendMessage(threadId, message, { agent_name }); + const sendPromise = sendMessage(threadId, message, { agent_name }); + if (message.files.length > 0) { + return sendPromise; + } + void sendPromise; }, [sendMessage, threadId, agent_name], ); @@ -243,7 +248,10 @@ export default function AgentChatPage() { ) } - disabled={env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"} + disabled={ + env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" || + isUploading + } onContextChange={(context) => setSettings("context", context)} onSubmit={handleSubmit} onStop={handleStop} diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index ed7d91c68..6f865ade8 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -109,7 +109,11 @@ export default function ChatPage() { const handleSubmit = useCallback( (message: PromptInputMessage) => { - void sendMessage(threadId, message); + const sendPromise = sendMessage(threadId, message); + if (message.files.length > 0) { + return sendPromise; + } + void sendPromise; }, [sendMessage, threadId], ); diff --git a/frontend/src/components/ai-elements/prompt-input.tsx b/frontend/src/components/ai-elements/prompt-input.tsx index 52a909cdd..4609c43d3 100644 --- a/frontend/src/components/ai-elements/prompt-input.tsx +++ b/frontend/src/components/ai-elements/prompt-input.tsx @@ -499,6 +499,10 @@ export const PromptInput = ({ // Keep a ref to files for cleanup on unmount (avoids stale closure) const filesRef = useRef(files); filesRef.current = files; + const providerTextRef = useRef(""); + if (usingProvider) { + providerTextRef.current = controller.textInput.value; + } const openFileDialogLocal = useCallback(() => { inputRef.current?.click(); @@ -768,6 +772,24 @@ export const PromptInput = ({ } // Convert blob URLs to data URLs asynchronously + const submittedFileIds = files.map((file) => file.id); + const clearSubmittedState = () => { + const currentFileIds = new Set(filesRef.current.map((file) => file.id)); + const submittedFileIdsStillPresent = submittedFileIds.filter((id) => + currentFileIds.has(id), + ); + if (submittedFileIdsStillPresent.length === filesRef.current.length) { + clear(); + } else { + for (const id of submittedFileIdsStillPresent) { + remove(id); + } + } + if (usingProvider && providerTextRef.current === text) { + controller.textInput.clear(); + } + }; + Promise.all( files.map(async ({ id, ...item }) => { if (item.file instanceof File) { @@ -793,20 +815,14 @@ export const PromptInput = ({ if (result instanceof Promise) { result .then(() => { - clear(); - if (usingProvider) { - controller.textInput.clear(); - } + clearSubmittedState(); }) .catch(() => { // Don't clear on error - user may want to retry }); } else { // Sync function completed without throwing, clear attachments - clear(); - if (usingProvider) { - controller.textInput.clear(); - } + clearSubmittedState(); } } catch { // Don't clear on error - user may want to retry diff --git a/frontend/src/components/workspace/input-box.tsx b/frontend/src/components/workspace/input-box.tsx index 9a33d41e6..6344a26d2 100644 --- a/frontend/src/components/workspace/input-box.tsx +++ b/frontend/src/components/workspace/input-box.tsx @@ -110,6 +110,7 @@ export function InputBox({ threadId, initialValue, onContextChange, + onFollowupsVisibilityChange, onSubmit, onStop, ...props @@ -142,7 +143,8 @@ export function InputBox({ reasoning_effort?: "minimal" | "low" | "medium" | "high"; }, ) => void; - onSubmit?: (message: PromptInputMessage) => void; + onFollowupsVisibilityChange?: (visible: boolean) => void; + onSubmit?: (message: PromptInputMessage) => void | Promise; onStop?: () => void; }) { const { t } = useI18n(); @@ -251,12 +253,12 @@ export function InputBox({ ); const handleSubmit = useCallback( - async (message: PromptInputMessage) => { + (message: PromptInputMessage) => { if (status === "streaming") { onStop?.(); return; } - if (!message.text) { + if (!message.text.trim() && message.files.length === 0) { return; } setFollowups([]); @@ -274,11 +276,14 @@ export function InputBox({ selectedModel?.supports_thinking ?? false, ), }); - setTimeout(() => onSubmit?.(message), 0); - return; + return new Promise((resolve, reject) => { + setTimeout(() => { + Promise.resolve(onSubmit?.(message)).then(resolve).catch(reject); + }, 0); + }); } - onSubmit?.(message); + return onSubmit?.(message); }, [ context, @@ -348,6 +353,14 @@ export function InputBox({ !followupsHidden && (followupsLoading || followups.length > 0); + useEffect(() => { + onFollowupsVisibilityChange?.(showFollowups); + }, [onFollowupsVisibilityChange, showFollowups]); + + useEffect(() => { + return () => onFollowupsVisibilityChange?.(false); + }, [onFollowupsVisibilityChange]); + useEffect(() => { messagesRef.current = thread.messages; }, [thread.messages]); diff --git a/frontend/tests/e2e/chat.spec.ts b/frontend/tests/e2e/chat.spec.ts index 490305de9..e608793df 100644 --- a/frontend/tests/e2e/chat.spec.ts +++ b/frontend/tests/e2e/chat.spec.ts @@ -48,4 +48,66 @@ test.describe("Chat workspace", () => { timeout: 10_000, }); }); + + test("keeps attachments visible while upload submit is pending", async ({ + page, + }) => { + let releaseUpload!: () => void; + const uploadCanFinish = new Promise((resolve) => { + releaseUpload = resolve; + }); + let uploadStarted!: () => void; + const uploadStartedPromise = new Promise((resolve) => { + uploadStarted = resolve; + }); + + await page.route("**/api/threads/*/uploads", async (route) => { + uploadStarted(); + await uploadCanFinish; + return route.fulfill({ + status: 200, + contentType: "application/json", + body: JSON.stringify({ + success: true, + message: "Uploaded", + files: [ + { + filename: "report.docx", + size: 12, + path: "report.docx", + virtual_path: "/mnt/user-data/uploads/report.docx", + artifact_url: "/api/threads/test/uploads/report.docx", + extension: ".docx", + }, + ], + }), + }); + }); + + await page.goto("/workspace/chats/new"); + + const textarea = page.getByPlaceholder(/how can i assist you/i); + await expect(textarea).toBeVisible({ timeout: 15_000 }); + const promptForm = page.locator("form").filter({ has: textarea }); + + await page.getByLabel("Upload files").setInputFiles({ + name: "report.docx", + mimeType: + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + buffer: Buffer.from("fake docx"), + }); + await expect(promptForm.getByText("report.docx")).toBeVisible(); + + await textarea.fill("Summarize this document"); + await textarea.press("Enter"); + + await uploadStartedPromise; + await expect(promptForm.getByText("report.docx")).toBeVisible(); + + releaseUpload(); + await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({ + timeout: 10_000, + }); + await expect(promptForm.getByText("report.docx")).toBeHidden(); + }); }); From 48e038f75283798f358c8522970ac0bc72022cb9 Mon Sep 17 00:00:00 2001 From: Yi Tang <6054101+yitang@users.noreply.github.com> Date: Fri, 15 May 2026 15:30:05 +0100 Subject: [PATCH 29/86] feat(channels): enhance Discord with mention-only mode, thread routing, and typing indicators (#2842) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(channels): enhance Discord with mention-only mode, thread routing, and typing indicators Add mention_only config to only respond when bot is mentioned, with allowed_channels override. Add thread_mode for Hermes-style auto-thread creation. Add periodic typing indicators while bot is processing. * fix(discord): include allowed_channels in mention_only skip condition (line 274) * docs: fix Discord config example to match boolean thread_mode implementation * style: format with ruff * fix(discord): apply Copilot review fixes and resolve lint errors - Remove unused Optional import - Fix thread_ts type hints to str | None - Fix has_mention logic for None values - Implement thread_mode fallback to channel replies on thread creation failure - Fix thread_mode docstring alignment - Fix allowed_channels comment formatting in config.example.yaml * fix(discord): reset context for orphaned threads in mention_only mode When a message arrives in a thread not tracked by _active_threads, clear thread_id and typing_target so the message falls through to the standard channel handling pipeline, which creates a fresh thread instead of incorrectly routing to the stale thread. * fix(discord): create new thread on @ when channel has existing tracked thread When mention_only is enabled and a user @-s the bot in a channel that already has a tracked thread, create a new thread instead of incorrectly routing to the old one. * fix(discord): allow no-@ thread replies while skipping no-@ channel messages The skip block for no-@ messages was too aggressive — it blocked continuation replies within tracked threads AND incorrectly routed no-@ channel messages to the existing thread. Now: - Thread message, no @ → routed to existing tracked thread - Channel message, no @ → skipped - Channel message, with @ → creates new thread * feat(discord): add checkmark reaction to acknowledge received messages * Move discord.py to optional dependency and auto-detect from config.yaml - Add discord extra to [project.optional-dependencies] in pyproject.toml - Update detect_uv_extras.py to map channels.discord.enabled: true -> --extra discord - Set UV_EXTRAS=discord in docker-compose-dev.yaml gateway env * fix(discord): persist thread-channel mappings to store for recovery after restart Discord's _active_threads dict was purely in-memory, so all channel-to-thread mappings were lost on server restart. This fix bridges ChannelStore into DiscordChannel: - Save thread mappings to store.json after every thread creation - Restore active threads from store on DiscordChannel startup - Pass channel_store to all channels via service.py config injection Store keys follow the pattern: discord:: * fix(discord): address Copilot review — fix types, typing targets, cross-thread safety, and config comments * fix(tests): add multitask_strategy param to mock for clarification follow-up test * fix(tests): explicitly set model_name=None for title middleware test isolation * fix(discord): use trigger_typing() instead of typing() for typing indicators discord.py 2.x TextChannel.typing() and Thread.typing() are async context managers, not one-shot coroutines. Use trigger_typing() for periodic typing indicator pings. * fix(discord): cancel typing tasks on channel shutdown Prevents 'Task was destroyed but it is pending' warnings when the Discord client stops while typing indicator loops are still running. * fix(scripts): detect nested YAML config for discord extra section_value() only matched top-level YAML sections. Added nested_section_value() that handles two-level nesting (e.g., channels.discord.enabled), so auto-detection of the discord extra works when config uses the standard nested format. * fix(docker): remove hard-coded UV_EXTRAS=discord from dev compose Relies on auto-detection via detect_uv_extras.py instead of forcing discord.py install even when channels.discord.enabled is false. Matches production docker-compose.yaml behavior (UV_EXTRAS:-). * refactor(nginx): move proxy_buffering/proxy_cache to server level DRY cleanup — these directives were repeated in 14 location blocks. Set at server level once, reducing duplication and risk of drift. * fix(discord): use dedicated JSON file for thread persistence Replace ChannelStore usage for Discord thread-ID persistence with a dedicated discord_threads.json file. ChannelStore is designed to map IM conversations to DeerFlow thread IDs — using it to persist Discord thread IDs was semantically wrong and confusing. Changes: - _save_thread() now reads/writes a simple {channel_id: thread_id} JSON dict - _load_active_threads() reads directly from the JSON file - File path derived from ChannelStore directory (when available) or defaults to ~/.deer-flow/channels/discord_threads.json - Removed unused ChannelStore import * fix(discord): address WillemJiang's code review comments on PR #2842 1. Remove semantically incorrect message_in_thread variable. At this code point (after the Thread case is handled above), we're guaranteed to be in a channel, not a thread. Always apply mention_only check here. 2. Add _active_thread_ids reverse-lookup set for O(1) thread ID membership checks instead of O(n) scan of _active_threads.values(). Keep the set in sync with _active_threads in _load_active_threads() and _save_thread(). 3. Add _thread_store_lock (threading.Lock) to protect _active_threads and the JSON file from concurrent access between the Discord loop thread (_run_client) and the main thread (_load_active_threads, _save_thread). --- backend/app/channels/discord.py | 302 +++++++++++++++++- backend/app/channels/manager.py | 23 +- backend/app/channels/service.py | 2 + backend/pyproject.toml | 1 + backend/tests/test_channels.py | 2 +- backend/tests/test_mindie_provider.py | 1 - .../tests/test_title_middleware_core_logic.py | 2 +- backend/uv.lock | 21 +- config.example.yaml | 8 + docker/nginx/nginx.conf | 24 +- docker/nginx/nginx.local.conf | 41 +++ scripts/detect_uv_extras.py | 81 +++++ 12 files changed, 482 insertions(+), 26 deletions(-) diff --git a/backend/app/channels/discord.py b/backend/app/channels/discord.py index 2d2889126..3b113c28d 100644 --- a/backend/app/channels/discord.py +++ b/backend/app/channels/discord.py @@ -3,8 +3,10 @@ from __future__ import annotations import asyncio +import json import logging import threading +from pathlib import Path from typing import Any from app.channels.base import Channel @@ -21,6 +23,12 @@ class DiscordChannel(Channel): Configuration keys (in ``config.yaml`` under ``channels.discord``): - ``bot_token``: Discord Bot token. - ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all. + - ``mention_only``: (optional) If true, only respond when the bot is mentioned. + - ``allowed_channels``: (optional) List of channel IDs where messages are always accepted + (even when mention_only is true). Use for channels where you want the bot to respond + without mentions. Empty = mention_only applies everywhere. + - ``thread_mode``: (optional) If true, group a channel conversation into a thread. + Default: same as ``mention_only``. """ def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None: @@ -32,6 +40,29 @@ class DiscordChannel(Channel): self._allowed_guilds.add(int(guild_id)) except (TypeError, ValueError): continue + self._mention_only: bool = bool(config.get("mention_only", False)) + self._thread_mode: bool = config.get("thread_mode", self._mention_only) + self._allowed_channels: set[str] = set() + for channel_id in config.get("allowed_channels", []): + self._allowed_channels.add(str(channel_id)) + + # Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON). + # Uses a dedicated JSON file separate from ChannelStore, which maps IM + # conversations to DeerFlow thread IDs — a different concern. + self._active_threads: dict[str, str] = {} + # Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()). + self._active_thread_ids: set[str] = set() + # Lock protecting _active_threads and the JSON file from concurrent access. + # _run_client (Discord loop thread) and the main thread both read/write. + self._thread_store_lock = threading.Lock() + store = config.get("channel_store") + if store is not None: + self._thread_store_path = store._path.parent / "discord_threads.json" + else: + self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json" + + # Typing indicator management + self._typing_tasks: dict[str, asyncio.Task] = {} self._client = None self._thread: threading.Thread | None = None @@ -75,12 +106,56 @@ class DiscordChannel(Channel): self._thread = threading.Thread(target=self._run_client, daemon=True) self._thread.start() + self._load_active_threads() logger.info("Discord channel started") + def _load_active_threads(self) -> None: + """Restore Discord thread mappings from the dedicated JSON file on startup.""" + with self._thread_store_lock: + try: + if not self._thread_store_path.exists(): + logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path) + return + data = json.loads(self._thread_store_path.read_text()) + self._active_threads.clear() + self._active_thread_ids.clear() + for channel_id, thread_id in data.items(): + self._active_threads[channel_id] = thread_id + self._active_thread_ids.add(thread_id) + if self._active_threads: + logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path) + except Exception: + logger.exception("[Discord] failed to load thread mappings") + + def _save_thread(self, channel_id: str, thread_id: str) -> None: + """Persist a Discord thread mapping to the dedicated JSON file.""" + with self._thread_store_lock: + try: + data: dict[str, str] = {} + if self._thread_store_path.exists(): + data = json.loads(self._thread_store_path.read_text()) + old_id = data.get(channel_id) + data[channel_id] = thread_id + # Update reverse-lookup set + if old_id: + self._active_thread_ids.discard(old_id) + self._active_thread_ids.add(thread_id) + self._thread_store_path.parent.mkdir(parents=True, exist_ok=True) + self._thread_store_path.write_text(json.dumps(data, indent=2)) + except Exception: + logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id) + async def stop(self) -> None: self._running = False self.bus.unsubscribe_outbound(self._on_outbound) + # Cancel all active typing indicator tasks + for target_id, task in list(self._typing_tasks.items()): + if not task.done(): + task.cancel() + logger.debug("[Discord] cancelled typing task for target %s", target_id) + self._typing_tasks.clear() + if self._client and self._discord_loop and self._discord_loop.is_running(): close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop) try: @@ -100,6 +175,10 @@ class DiscordChannel(Channel): logger.info("Discord channel stopped") async def send(self, msg: OutboundMessage) -> None: + # Stop typing indicator once we're sending the response + stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop) + await asyncio.wrap_future(stop_future) + target = await self._resolve_target(msg) if target is None: logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts) @@ -111,6 +190,9 @@ class DiscordChannel(Channel): await asyncio.wrap_future(send_future) async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: + stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop) + await asyncio.wrap_future(stop_future) + target = await self._resolve_target(msg) if target is None: logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts) @@ -130,6 +212,41 @@ class DiscordChannel(Channel): logger.exception("[Discord] failed to upload file: %s", attachment.filename) return False + async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None: + """Starts a loop to send periodic typing indicators.""" + target_id = thread_ts or chat_id + if target_id in self._typing_tasks: + return # Already typing for this target + + async def _typing_loop(): + try: + while True: + try: + await channel.trigger_typing() + except Exception: + pass + await asyncio.sleep(10) + except asyncio.CancelledError: + pass + + task = asyncio.create_task(_typing_loop()) + self._typing_tasks[target_id] = task + + async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None: + """Stops the typing loop for a specific target.""" + target_id = thread_ts or chat_id + task = self._typing_tasks.pop(target_id, None) + if task and not task.done(): + task.cancel() + logger.debug("[Discord] stopped typing indicator for target %s", target_id) + + async def _add_reaction(self, message) -> None: + """Add a checkmark reaction to acknowledge the message was received.""" + try: + await message.add_reaction("✅") + except Exception: + logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True) + async def _on_message(self, message) -> None: if not self._running or not self._client: return @@ -152,15 +269,143 @@ class DiscordChannel(Channel): if self._discord_module is None: return - if isinstance(message.channel, self._discord_module.Thread): - chat_id = str(message.channel.parent_id or message.channel.id) - thread_id = str(message.channel.id) + # Determine whether the bot is mentioned in this message + user = self._client.user if self._client else None + if user: + bot_mention = user.mention # <@ID> + alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant) + standard_mention = f"<@{user.id}>" else: - thread = await self._create_thread(message) - if thread is None: + bot_mention = None + alt_mention = None + standard_mention = "" + has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content) + + # Strip mention from text for processing + if has_mention: + text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip() + # Don't return early if text is empty — still process the mention (e.g., create thread) + + # --- Determine thread/channel routing and typing target --- + thread_id = None + chat_id = None + typing_target = None # The Discord object to type into + + if isinstance(message.channel, self._discord_module.Thread): + # --- Message already inside a thread --- + thread_obj = message.channel + thread_id = str(thread_obj.id) + chat_id = str(thread_obj.parent_id or thread_obj.id) + typing_target = thread_obj + + # If this is a known active thread, process normally + if thread_id in self._active_thread_ids: + msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT + inbound = self._make_inbound( + chat_id=chat_id, + user_id=str(message.author.id), + text=text, + msg_type=msg_type, + thread_ts=thread_id, + metadata={ + "guild_id": str(guild.id) if guild else None, + "channel_id": str(message.channel.id), + "message_id": str(message.id), + }, + ) + inbound.topic_id = thread_id + self._publish(inbound) + # Start typing indicator in the thread + if typing_target: + asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id)) + asyncio.create_task(self._add_reaction(message)) return - chat_id = str(message.channel.id) - thread_id = str(thread.id) + + # Thread not tracked (orphaned) — create new thread and handle below + logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id) + thread_id = None + typing_target = None + + # At this point we're guaranteed to be in a channel, not a thread + # (the Thread case is handled above). Apply mention_only for all + # non-thread messages — no special case needed. + channel_id = str(message.channel.id) + + # Check if there's an active thread for this channel + if channel_id in self._active_threads: + # respect mention_only: if enabled, only process messages that mention the bot + # (unless the channel is in allowed_channels) + # Messages within a thread are always allowed through (continuation). + # At this code point we know the message is in a channel, not a thread + # (Thread case handled above), so always apply the check. + if self._mention_only and not has_mention and channel_id not in self._allowed_channels: + logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id) + return + # mention_only + fresh @ → create new thread instead of routing to existing one + if self._mention_only and has_mention: + thread_obj = await self._create_thread(message) + if thread_obj is not None: + target_thread_id = str(thread_obj.id) + self._active_threads[channel_id] = target_thread_id + self._save_thread(channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = thread_obj + logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id) + else: + logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id) + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel + else: + # Existing session → route to the existing thread + target_thread_id = self._active_threads[channel_id] + logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = await self._get_channel_or_thread(target_thread_id) + elif self._mention_only and not has_mention and channel_id not in self._allowed_channels: + # Not mentioned and not in an allowed channel → skip + logger.debug("[Discord] skipping message without mention in channel %s", channel_id) + return + elif self._mention_only and has_mention: + # First mention in this channel → create thread + thread_obj = await self._create_thread(message) + if thread_obj is not None: + target_thread_id = str(thread_obj.id) + self._active_threads[channel_id] = target_thread_id + self._save_thread(channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = thread_obj # Type into the new thread + logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name) + else: + # Fallback: thread creation failed (disabled/permissions), reply in channel + logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id) + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel # Type into the channel + elif self._thread_mode: + # thread_mode but mention_only is False → create thread anyway for conversation grouping + thread_obj = await self._create_thread(message) + if thread_obj is None: + # Thread creation failed (disabled/permissions), fall back to channel replies + logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id) + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel # Type into the channel + else: + target_thread_id = str(thread_obj.id) + self._active_threads[channel_id] = target_thread_id + self._save_thread(channel_id, target_thread_id) + thread_id = target_thread_id + chat_id = channel_id + typing_target = thread_obj # Type into the new thread + else: + # No threading — reply directly in channel + thread_id = channel_id + chat_id = channel_id + typing_target = message.channel # Type into the channel msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT inbound = self._make_inbound( @@ -177,6 +422,15 @@ class DiscordChannel(Channel): ) inbound.topic_id = thread_id + # Start typing indicator in the correct target (thread or channel) + if typing_target: + asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id)) + + self._publish(inbound) + asyncio.create_task(self._add_reaction(message)) + + def _publish(self, inbound) -> None: + """Publish an inbound message to the main event loop.""" if self._main_loop and self._main_loop.is_running(): future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop) future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None) @@ -198,14 +452,40 @@ class DiscordChannel(Channel): async def _create_thread(self, message): try: + if self._discord_module is None: + return None + + # Only TextChannel (type 0) and NewsChannel (type 10) support threads + channel_type = message.channel.type + if channel_type not in ( + self._discord_module.ChannelType.text, + self._discord_module.ChannelType.news, + ): + logger.info( + "[Discord] channel type %s (%s) does not support threads", + channel_type.value, + channel_type.name, + ) + return None + thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100] return await message.create_thread(name=thread_name) + except self._discord_module.errors.HTTPException as exc: + if exc.code == 50024: + logger.info( + "[Discord] cannot create thread in channel %s (error code 50024): %s", + message.channel.id, + channel_type.name if (channel_type := message.channel.type) else "unknown", + ) + else: + logger.exception( + "[Discord] failed to create thread for message=%s (HTTPException %s)", + message.id, + exc.code, + ) + return None except Exception: logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id) - try: - await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.") - except Exception: - pass return None async def _resolve_target(self, msg: OutboundMessage): diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index e59dbcf2c..aa52fa298 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -787,13 +787,22 @@ class ChannelManager: return logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) - result = await client.runs.wait( - thread_id, - assistant_id, - input={"messages": [{"role": "human", "content": msg.text}]}, - config=run_config, - context=run_context, - ) + try: + result = await client.runs.wait( + thread_id, + assistant_id, + input={"messages": [{"role": "human", "content": msg.text}]}, + config=run_config, + context=run_context, + multitask_strategy="reject", + ) + except Exception as exc: + if _is_thread_busy_error(exc): + logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id) + await self._send_error(msg, THREAD_BUSY_MESSAGE) + return + else: + raise response_text = _extract_response_text(result) artifacts = _extract_artifacts(result) diff --git a/backend/app/channels/service.py b/backend/app/channels/service.py index 4a3df9060..1b9526297 100644 --- a/backend/app/channels/service.py +++ b/backend/app/channels/service.py @@ -167,6 +167,8 @@ class ChannelService: return False try: + config = dict(config) + config["channel_store"] = self.store channel = channel_cls(bus=self.bus, config=config) self._channels[name] = channel await channel.start() diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 6d2edb0bb..082c3d07d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -25,6 +25,7 @@ dependencies = [ [project.optional-dependencies] postgres = ["deerflow-harness[postgres]"] +discord = ["discord.py>=2.7.0"] [dependency-groups] dev = [ diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index d68701c4e..f85062a17 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -761,7 +761,7 @@ class TestChannelManager: history_by_checkpoint: dict[tuple[str, str], list[str]] = {} - async def _runs_wait(thread_id, assistant_id, *, input, config, context): + async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None): del assistant_id, context # unused in this test, kept for signature parity checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns") diff --git a/backend/tests/test_mindie_provider.py b/backend/tests/test_mindie_provider.py index 78bc0d972..cfbffbb07 100644 --- a/backend/tests/test_mindie_provider.py +++ b/backend/tests/test_mindie_provider.py @@ -454,7 +454,6 @@ class TestAStream: @pytest.mark.asyncio async def test_with_tools_emits_tool_call_chunk(self): - tool_calls = [{"name": "fn", "args": {}, "id": "c1"}] with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None): mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls) diff --git a/backend/tests/test_title_middleware_core_logic.py b/backend/tests/test_title_middleware_core_logic.py index 5395f816e..3fdf4d3f9 100644 --- a/backend/tests/test_title_middleware_core_logic.py +++ b/backend/tests/test_title_middleware_core_logic.py @@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic: assert middleware._should_generate_title(state) is False def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch): - _set_test_title_config(max_chars=12) + _set_test_title_config(max_chars=12, model_name=None) middleware = TitleMiddleware() model = MagicMock() model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题")) diff --git a/backend/uv.lock b/backend/uv.lock index cd6bc8543..9cc2030fa 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -763,6 +763,9 @@ dependencies = [ ] [package.optional-dependencies] +discord = [ + { name = "discord-py" }, +] postgres = [ { name = "deerflow-harness", extra = ["postgres"] }, ] @@ -781,6 +784,7 @@ requires-dist = [ { name = "deerflow-harness", editable = "packages/harness" }, { name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" }, { name = "dingtalk-stream", specifier = ">=0.24.3" }, + { name = "discord-py", marker = "extra == 'discord'", specifier = ">=2.7.0" }, { name = "email-validator", specifier = ">=2.0.0" }, { name = "fastapi", specifier = ">=0.115.0" }, { name = "httpx", specifier = ">=0.28.0" }, @@ -795,7 +799,7 @@ requires-dist = [ { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" }, { name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" }, ] -provides-extras = ["postgres"] +provides-extras = ["postgres", "discord"] [package.metadata.requires-dev] dev = [ @@ -923,6 +927,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" }, ] +[[package]] +name = "discord-py" +version = "2.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/57/9a2d9abdabdc9db8ef28ce0cf4129669e1c8717ba28d607b5ba357c4de3b/discord_py-2.7.1.tar.gz", hash = "sha256:24d5e6a45535152e4b98148a9dd6b550d25dc2c9fb41b6d670319411641249da", size = 1106326, upload-time = "2026-03-03T18:40:46.24Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/a7/17208c3b3f92319e7fad259f1c6d5a5baf8fd0654c54846ced329f83c3eb/discord_py-2.7.1-py3-none-any.whl", hash = "sha256:849dca2c63b171146f3a7f3f8acc04248098e9e6203412ce3cf2745f284f7439", size = 1227550, upload-time = "2026-03-03T18:40:44.492Z" }, +] + [[package]] name = "distro" version = "1.9.0" diff --git a/config.example.yaml b/config.example.yaml index 9a8d07bf4..7396f6cfb 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1029,6 +1029,14 @@ run_events: # client_secret: $DINGTALK_CLIENT_SECRET # allowed_users: [] # empty = allow all # card_template_id: "" # Optional: AI Card template ID for streaming updates +# +# discord: +# enabled: false +# bot_token: $DISCORD_BOT_TOKEN +# allowed_guilds: [] # empty = allow all guilds; can also be a single guild ID +# mention_only: false # If true, only respond when the bot is mentioned +# allowed_channels: [] # Optional: channel IDs exempt from mention_only (bot responds without mention) +# thread_mode: false # If true, group a channel conversation into a thread # ============================================================================ # Guardrails Configuration diff --git a/docker/nginx/nginx.conf b/docker/nginx/nginx.conf index 45be0ab97..18481adb3 100644 --- a/docker/nginx/nginx.conf +++ b/docker/nginx/nginx.conf @@ -28,6 +28,10 @@ http { set $gateway_upstream gateway:8001; set $frontend_upstream frontend:3000; + # Default proxy settings for all locations (streaming/SSE support) + proxy_buffering off; + proxy_cache off; + # Keep the unified nginx endpoint same-origin by default. When split # frontend/backend or port-forwarded deployments need browser CORS, # configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and @@ -49,8 +53,6 @@ http { proxy_set_header Connection ''; # SSE/Streaming support - proxy_buffering off; - proxy_cache off; proxy_set_header X-Accel-Buffering no; # Timeouts for long-running requests @@ -70,6 +72,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Memory endpoint @@ -80,6 +83,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: MCP configuration endpoint @@ -90,6 +94,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Skills configuration endpoint @@ -100,6 +105,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Agents endpoint @@ -110,6 +116,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Custom API: Uploads endpoint @@ -124,6 +131,8 @@ http { # Large file upload support client_max_body_size 100M; proxy_request_buffering off; + + # Disable response buffering to avoid permission errors } # Custom API: Other endpoints under /api/threads @@ -134,6 +143,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # API Documentation: Swagger UI @@ -144,6 +154,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # API Documentation: ReDoc @@ -154,6 +165,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # API Documentation: OpenAPI Schema @@ -164,6 +176,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Health check endpoint (gateway) @@ -174,6 +187,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # ── Provisioner API (sandbox management) ──────────────────────── @@ -187,6 +201,7 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + } # Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*). @@ -198,6 +213,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + # Disable buffering to avoid permission errors when nginx + # runs as a non-root user (e.g. local development). } # All other requests go to frontend @@ -220,4 +238,4 @@ http { proxy_read_timeout 600s; } } -} +} \ No newline at end of file diff --git a/docker/nginx/nginx.local.conf b/docker/nginx/nginx.local.conf index 68ca1f1ac..035406862 100644 --- a/docker/nginx/nginx.local.conf +++ b/docker/nginx/nginx.local.conf @@ -70,6 +70,11 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + # Disable buffering to avoid permission errors when nginx + # runs as a non-root user (e.g. local development). + proxy_buffering off; + proxy_cache off; } # Custom API: Memory endpoint @@ -80,6 +85,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: MCP configuration endpoint @@ -90,6 +98,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: Skills configuration endpoint @@ -100,6 +111,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: Agents endpoint @@ -110,6 +124,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Custom API: Uploads endpoint @@ -124,6 +141,10 @@ http { # Large file upload support client_max_body_size 100M; proxy_request_buffering off; + + # Disable response buffering to avoid permission errors + proxy_buffering off; + proxy_cache off; } # Custom API: Other endpoints under /api/threads @@ -134,6 +155,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # API Documentation: Swagger UI @@ -144,6 +168,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # API Documentation: ReDoc @@ -154,6 +181,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # API Documentation: OpenAPI Schema @@ -164,6 +194,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Health check endpoint (gateway) @@ -174,6 +207,9 @@ http { proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Proto $scheme; + + proxy_buffering off; + proxy_cache off; } # Catch-all for any /api/* prefix not matched by a more specific block above. @@ -193,6 +229,11 @@ http { # Auth endpoints set HttpOnly cookies — make sure nginx doesn't # strip the Set-Cookie header from upstream responses. proxy_pass_header Set-Cookie; + + # Disable buffering to avoid permission errors when nginx + # runs as a non-root user (e.g. local development). + proxy_buffering off; + proxy_cache off; } # All other requests go to frontend diff --git a/scripts/detect_uv_extras.py b/scripts/detect_uv_extras.py index 91a9bd0ad..e6f4e8a24 100755 --- a/scripts/detect_uv_extras.py +++ b/scripts/detect_uv_extras.py @@ -72,6 +72,7 @@ def find_config_file() -> Path | None: _SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$") +_INDENTED_SECTION_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*$") _KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$") @@ -141,6 +142,84 @@ def section_value(lines: list[str], section: str, key: str) -> str | None: return None +def nested_section_value(lines: list[str], section_path: str, key: str) -> str | None: + """Return the value of a nested YAML key like ``channels.discord.enabled``. + + Handles two levels of nesting: + channels: + discord: + enabled: true + """ + parts = section_path.split(".") + if len(parts) != 2: + return None + parent_section, child_section = parts + + inside_parent = False + inside_child = False + parent_indent: int | None = None + child_indent: int | None = None + + for raw in lines: + line = _strip_comment(raw) + if not line.strip(): + continue + + stripped = line.lstrip() + indent = len(line) - len(stripped) + + # Top-level section match + sect_match = _SECTION_RE.match(line) + if sect_match: + if indent == 0: + inside_parent = sect_match.group(1) == parent_section + inside_child = False + parent_indent = None + child_indent = None + continue + + if not inside_parent: + continue + + # Track parent indent from first child + if parent_indent is None and indent > 0: + parent_indent = indent + + # If indent goes back to 0, we left the parent section + if indent == 0: + inside_parent = False + inside_child = False + continue + + # Check if we're at the parent's child level (subsection) + if parent_indent is not None and indent == parent_indent: + # This could be a subsection or a direct key of parent + sub_match = _INDENTED_SECTION_RE.match(line) + if sub_match and sub_match.group(1) == child_section: + inside_child = True + child_indent = None + continue + else: + inside_child = False + continue + + if not inside_child: + continue + + # We're inside the subsection — track child indent + if child_indent is None and indent > (parent_indent or 0): + child_indent = indent + + if child_indent is not None and indent != child_indent: + continue + + key_match = _KEY_RE.match(line) + if key_match and key_match.group(1) == key: + return _unquote(key_match.group(2).strip()) + + return None + + def detect_from_config(path: Path) -> list[str]: try: text = path.read_text(encoding="utf-8", errors="replace") @@ -152,6 +231,8 @@ def detect_from_config(path: Path) -> list[str]: extras.add("postgres") if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres": extras.add("postgres") + if (nested_section_value(lines, "channels.discord", "enabled") or "").lower() == "true": + extras.add("discord") return sorted(extras) From 6d3cffb4f04d100929ed8b7eeb787503eeab4fd8 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Sat, 16 May 2026 02:48:19 +0200 Subject: [PATCH 30/86] fix(frontend): deduplicate restored thread messages (#2958) * fix(frontend): fix duplicate messages when reopening agent sessions (#2957) * make format * fix(frontend): retry pending thread history loads --- frontend/src/core/threads/hooks.ts | 175 ++++++++++++++---- .../unit/core/threads/message-merge.test.ts | 64 +++++++ 2 files changed, 198 insertions(+), 41 deletions(-) create mode 100644 frontend/tests/unit/core/threads/message-merge.test.ts diff --git a/frontend/src/core/threads/hooks.ts b/frontend/src/core/threads/hooks.ts index adf9dbbb6..fba3edd0c 100644 --- a/frontend/src/core/threads/hooks.ts +++ b/frontend/src/core/threads/hooks.ts @@ -45,15 +45,60 @@ type SendMessageOptions = { additionalKwargs?: Record; }; -function mergeMessages( +function isNonEmptyString(value: string | undefined): value is string { + return typeof value === "string" && value.length > 0; +} + +function messageIdentity(message: Message): string | undefined { + if ( + "tool_call_id" in message && + typeof message.tool_call_id === "string" && + message.tool_call_id.length > 0 + ) { + return `tool:${message.tool_call_id}`; + } + if (typeof message.id === "string" && message.id.length > 0) { + return `message:${message.id}`; + } + return undefined; +} + +function dedupeMessagesByIdentity(messages: Message[]): Message[] { + const lastIndexByIdentity = new Map(); + + messages.forEach((message, index) => { + const identity = messageIdentity(message); + if (identity) { + lastIndexByIdentity.set(identity, index); + } + }); + + return messages.filter((message, index) => { + const identity = messageIdentity(message); + return !identity || lastIndexByIdentity.get(identity) === index; + }); +} + +function findLatestUnloadedRunIndex( + runs: Run[], + loadedRunIds: ReadonlySet, +): number { + for (let i = runs.length - 1; i >= 0; i--) { + const run = runs[i]; + if (run && !loadedRunIds.has(run.run_id)) { + return i; + } + } + return -1; +} + +export function mergeMessages( historyMessages: Message[], threadMessages: Message[], optimisticMessages: Message[], ): Message[] { const threadMessageIds = new Set( - threadMessages - .map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id)) - .filter(Boolean), + threadMessages.map(messageIdentity).filter(isNonEmptyString), ); // The overlap is a contiguous suffix of historyMessages (newest history == oldest thread). @@ -65,28 +110,19 @@ function mergeMessages( if (!msg) { continue; } - if ( - (msg?.id && threadMessageIds.has(msg.id)) || - ("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id)) - ) { + const identity = messageIdentity(msg); + if (identity && threadMessageIds.has(identity)) { cutoff = i; } else { break; } } - return [ + return dedupeMessagesByIdentity([ ...historyMessages.slice(0, cutoff), ...threadMessages, ...optimisticMessages, - ]; -} - -function messageIdentity(message: Message): string | undefined { - if ("tool_call_id" in message) { - return message.tool_call_id; - } - return message.id; + ]); } function getMessagesAfterBaseline( @@ -627,48 +663,105 @@ export function useThreadHistory(threadId: string) { const runsRef = useRef(runs.data ?? []); const indexRef = useRef(-1); const loadingRef = useRef(false); + const pendingLoadRef = useRef(false); + const loadingRunIdRef = useRef(null); + const loadedRunIdsRef = useRef>(new Set()); const [loading, setLoading] = useState(false); const [messages, setMessages] = useState([]); - loadingRef.current = loading; const loadMessages = useCallback(async () => { + if (loadingRef.current) { + const pendingRunIndex = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + const pendingRun = runsRef.current[pendingRunIndex]; + if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) { + pendingLoadRef.current = true; + } + return; + } if (runsRef.current.length === 0) { return; } - const run = runsRef.current[indexRef.current]; - if (!run || loadingRef.current) { - return; - } + + loadingRef.current = true; + setLoading(true); + try { - setLoading(true); - const result: { data: RunMessage[]; hasMore: boolean } = await fetch( - `${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`, - { - method: "GET", - headers: { - "Content-Type": "application/json", + do { + pendingLoadRef.current = false; + + const nextRunIndex = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + indexRef.current = nextRunIndex; + + const run = runsRef.current[nextRunIndex]; + if (!run) { + indexRef.current = -1; + return; + } + + const requestThreadId = threadIdRef.current; + loadingRunIdRef.current = run.run_id; + const result: { data: RunMessage[]; hasMore: boolean } = await fetch( + `${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`, + { + method: "GET", + headers: { + "Content-Type": "application/json", + }, + credentials: "include", }, - credentials: "include", - }, - ).then((res) => { - return res.json(); - }); - const _messages = result.data - .filter((m) => !m.metadata.caller?.startsWith("middleware:")) - .map((m) => m.content); - setMessages((prev) => [..._messages, ...prev]); - indexRef.current -= 1; + ).then((res) => { + return res.json(); + }); + const _messages = result.data + .filter((m) => !m.metadata.caller?.startsWith("middleware:")) + .map((m) => m.content); + if (threadIdRef.current !== requestThreadId) { + return; + } + setMessages((prev) => + dedupeMessagesByIdentity([..._messages, ...prev]), + ); + loadedRunIdsRef.current.add(run.run_id); + indexRef.current = findLatestUnloadedRunIndex( + runsRef.current, + loadedRunIdsRef.current, + ); + } while (pendingLoadRef.current); } catch (err) { console.error(err); } finally { + loadingRef.current = false; + loadingRunIdRef.current = null; setLoading(false); } }, []); useEffect(() => { + const threadChanged = threadIdRef.current !== threadId; threadIdRef.current = threadId; + + if (threadChanged) { + runsRef.current = []; + indexRef.current = -1; + pendingLoadRef.current = false; + loadingRunIdRef.current = null; + loadedRunIdsRef.current = new Set(); + loadingRef.current = false; + setLoading(false); + setMessages([]); + } + if (runs.data && runs.data.length > 0) { runsRef.current = runs.data ?? []; - indexRef.current = runs.data.length - 1; + indexRef.current = findLatestUnloadedRunIndex( + runs.data, + loadedRunIdsRef.current, + ); } loadMessages().catch(() => { toast.error("Failed to load thread history."); @@ -677,7 +770,7 @@ export function useThreadHistory(threadId: string) { const appendMessages = useCallback((_messages: Message[]) => { setMessages((prev) => { - return [...prev, ..._messages]; + return dedupeMessagesByIdentity([...prev, ..._messages]); }); }, []); const hasMore = indexRef.current >= 0 || !runs.data; diff --git a/frontend/tests/unit/core/threads/message-merge.test.ts b/frontend/tests/unit/core/threads/message-merge.test.ts new file mode 100644 index 000000000..9b29aebc9 --- /dev/null +++ b/frontend/tests/unit/core/threads/message-merge.test.ts @@ -0,0 +1,64 @@ +import type { Message } from "@langchain/langgraph-sdk"; +import { expect, test } from "vitest"; + +import { mergeMessages } from "@/core/threads/hooks"; + +test("mergeMessages removes duplicate messages already present in history", () => { + const human = { + id: "human-1", + type: "human", + content: "Design an agent", + } as Message; + const ai = { + id: "ai-1", + type: "ai", + content: "Let's design it.", + } as Message; + + expect(mergeMessages([human, ai, human, ai], [], [])).toEqual([human, ai]); +}); + +test("mergeMessages lets live thread messages replace overlapping history", () => { + const oldHuman = { + id: "human-1", + type: "human", + content: "old", + } as Message; + const liveHuman = { + id: "human-1", + type: "human", + content: "live", + } as Message; + const oldAi = { + id: "ai-1", + type: "ai", + content: "old", + } as Message; + const liveAi = { + id: "ai-1", + type: "ai", + content: "live", + } as Message; + + expect(mergeMessages([oldHuman, oldAi], [liveHuman, liveAi], [])).toEqual([ + liveHuman, + liveAi, + ]); +}); + +test("mergeMessages deduplicates tool messages by tool_call_id", () => { + const oldTool = { + id: "tool-message-old", + type: "tool", + tool_call_id: "call-1", + content: "old", + } as Message; + const liveTool = { + id: "tool-message-live", + type: "tool", + tool_call_id: "call-1", + content: "live", + } as Message; + + expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]); +}); From 6d611c2bf65c817580b684c0312f8b37f5fa5cbf Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Sat, 16 May 2026 09:24:40 +0800 Subject: [PATCH 31/86] fix(auth): persist auto-generated JWT secret to survive restarts (#2933) * fix(auth): persist auto-generated JWT secret to survive restarts When AUTH_JWT_SECRET is not set, the auto-generated secret is now written to .deer-flow/.jwt_secret (mode 0600) and reused on subsequent starts. This prevents session invalidation on every restart while still allowing explicit AUTH_JWT_SECRET in .env to take precedence. * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix the lint errors of backend --------- Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- backend/app/gateway/auth/config.py | 34 ++++++++++++++++-- backend/docs/AUTH_UPGRADE.md | 4 +-- backend/tests/test_auth_config.py | 58 ++++++++++++++++++++++++------ 3 files changed, 80 insertions(+), 16 deletions(-) diff --git a/backend/app/gateway/auth/config.py b/backend/app/gateway/auth/config.py index 4734f0897..27c1984f1 100644 --- a/backend/app/gateway/auth/config.py +++ b/backend/app/gateway/auth/config.py @@ -8,6 +8,8 @@ from pydantic import BaseModel, Field logger = logging.getLogger(__name__) +_SECRET_FILE = ".jwt_secret" + class AuthConfig(BaseModel): """JWT and auth-related configuration. Parsed once at startup. @@ -30,6 +32,32 @@ class AuthConfig(BaseModel): _auth_config: AuthConfig | None = None +def _load_or_create_secret() -> str: + """Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one.""" + from deerflow.config.paths import get_paths + + paths = get_paths() + secret_file = paths.base_dir / _SECRET_FILE + + try: + if secret_file.exists(): + secret = secret_file.read_text(encoding="utf-8").strip() + if secret: + return secret + except OSError as exc: + raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc + + secret = secrets.token_urlsafe(32) + try: + secret_file.parent.mkdir(parents=True, exist_ok=True) + fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600) + with os.fdopen(fd, "w", encoding="utf-8") as fh: + fh.write(secret) + except OSError as exc: + raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc + return secret + + def get_auth_config() -> AuthConfig: """Get the global AuthConfig instance. Parses from env on first call.""" global _auth_config @@ -39,11 +67,11 @@ def get_auth_config() -> AuthConfig: load_dotenv() jwt_secret = os.environ.get("AUTH_JWT_SECRET") if not jwt_secret: - jwt_secret = secrets.token_urlsafe(32) + jwt_secret = _load_or_create_secret() os.environ["AUTH_JWT_SECRET"] = jwt_secret logger.warning( - "⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. " - "Sessions will be invalidated on restart. " + "⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret " + "persisted to .jwt_secret. Sessions will survive restarts. " "For production, add AUTH_JWT_SECRET to your .env file: " 'python -c "import secrets; print(secrets.token_urlsafe(32))"' ) diff --git a/backend/docs/AUTH_UPGRADE.md b/backend/docs/AUTH_UPGRADE.md index 75fe8b3cb..b54283d24 100644 --- a/backend/docs/AUTH_UPGRADE.md +++ b/backend/docs/AUTH_UPGRADE.md @@ -99,7 +99,7 @@ rm -f backend/.deer-flow/data/deerflow.db | `.deer-flow/users/{user_id}/memory.json` | 用户级 memory | | `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory | | `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) | -| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) | +| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持) | ### 生产环境建议 @@ -137,4 +137,4 @@ python -c "import secrets; print(secrets.token_urlsafe(32))" | 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` | | `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin | | 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 | -| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 | +| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 | diff --git a/backend/tests/test_auth_config.py b/backend/tests/test_auth_config.py index 21b8bd81b..61d1d7d2e 100644 --- a/backend/tests/test_auth_config.py +++ b/backend/tests/test_auth_config.py @@ -5,28 +5,26 @@ from unittest.mock import patch import pytest -from app.gateway.auth.config import AuthConfig +import app.gateway.auth.config as cfg def test_auth_config_defaults(): - config = AuthConfig(jwt_secret="test-secret-key-123") + config = cfg.AuthConfig(jwt_secret="test-secret-key-123") assert config.token_expiry_days == 7 def test_auth_config_token_expiry_range(): - AuthConfig(jwt_secret="s", token_expiry_days=1) - AuthConfig(jwt_secret="s", token_expiry_days=30) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=1) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=30) with pytest.raises(Exception): - AuthConfig(jwt_secret="s", token_expiry_days=0) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=0) with pytest.raises(Exception): - AuthConfig(jwt_secret="s", token_expiry_days=31) + cfg.AuthConfig(jwt_secret="s", token_expiry_days=31) def test_auth_config_from_env(): env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"} with patch.dict(os.environ, env, clear=False): - import app.gateway.auth.config as cfg - old = cfg._auth_config cfg._auth_config = None try: @@ -36,19 +34,57 @@ def test_auth_config_from_env(): cfg._auth_config = old -def test_auth_config_missing_secret_generates_ephemeral(caplog): +def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog): import logging - import app.gateway.auth.config as cfg + from deerflow.config.paths import Paths old = cfg._auth_config cfg._auth_config = None + secret_file = tmp_path / ".jwt_secret" try: with patch.dict(os.environ, {}, clear=True): os.environ.pop("AUTH_JWT_SECRET", None) - with caplog.at_level(logging.WARNING): + with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING): config = cfg.get_auth_config() assert config.jwt_secret assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages) + assert secret_file.exists() + assert secret_file.read_text().strip() == config.jwt_secret + finally: + cfg._auth_config = old + + +def test_auth_config_reuses_persisted_secret(tmp_path): + from deerflow.config.paths import Paths + + old = cfg._auth_config + cfg._auth_config = None + persisted = "persisted-secret-from-file-min-32-chars!!" + (tmp_path / ".jwt_secret").write_text(persisted, encoding="utf-8") + try: + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("AUTH_JWT_SECRET", None) + with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)): + config = cfg.get_auth_config() + assert config.jwt_secret == persisted + finally: + cfg._auth_config = old + + +def test_auth_config_empty_secret_file_generates_new(tmp_path): + from deerflow.config.paths import Paths + + old = cfg._auth_config + cfg._auth_config = None + (tmp_path / ".jwt_secret").write_text("", encoding="utf-8") + try: + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("AUTH_JWT_SECRET", None) + with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)): + config = cfg.get_auth_config() + assert config.jwt_secret + assert len(config.jwt_secret) > 20 + assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret finally: cfg._auth_config = old From 4538c3229875aebebcdeac5c3ffb9904c2d7944d Mon Sep 17 00:00:00 2001 From: pereverzev <136249885+daniil-pereverzev@users.noreply.github.com> Date: Sat, 16 May 2026 12:55:34 +0300 Subject: [PATCH 32/86] Fix type check for 'thinking' in message content (#2964) * Fix type check for 'thinking' in message content When Gemini via Vertex AI returns content as a string inside an array, the in operator throws TypeError because it can't be used on primitives. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Zil6n <136249885+Zil6n@users.noreply.github.com> Co-authored-by: Willem Jiang Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- frontend/src/core/messages/utils.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/src/core/messages/utils.ts b/frontend/src/core/messages/utils.ts index 3f1fef9ad..22f985009 100644 --- a/frontend/src/core/messages/utils.ts +++ b/frontend/src/core/messages/utils.ts @@ -251,7 +251,7 @@ export function extractReasoningContentFromMessage(message: Message) { } if (Array.isArray(message.content)) { const part = message.content[0]; - if (part && "thinking" in part) { + if (part && typeof part === "object" && "thinking" in part) { return part.thinking as string; } } From 380255f722ccc568f37d4b63f3c02bdeeb399465 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Sun, 17 May 2026 08:26:04 +0800 Subject: [PATCH 33/86] fix(sandbox): uphold /mnt/user-data contract at Sandbox API boundary (#2873) (#2881) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(sandbox): uphold /mnt/user-data contract at Sandbox API boundary (#2873) LocalSandboxProvider used a process-wide singleton with no /mnt/user-data mapping, forcing every caller to translate virtual paths via tools.py before invoking the public Sandbox API. AIO already exposes /mnt/user-data natively (per-thread bind mounts), so the same code path behaved differently across implementations — and direct callers like uploads.py:282 / feishu.py:389 only worked thanks to the `uses_thread_data_mounts` workaround flag. Switch the provider to a dual-track cache: keep the `"local"` singleton for legacy acquire(None) callers (backward-compat for existing tests and scripts), and create a per-thread LocalSandbox with id `"local:{tid}"` for acquire(thread_id). Each per-thread instance carries PathMapping entries for /mnt/user-data, its three subdirs, and /mnt/acp-workspace, mirroring how AioSandboxProvider mounts those paths into its container. is_local_sandbox() now recognises both id formats. `_agent_written_paths` becomes per-thread (it was a process-wide set that leaked across threads — a latent isolation bug also fixed by this change). Verified via TDD: a new contract test suite hits the public Sandbox API directly (write/read/list/exec/glob/grep/update + per-thread isolation + lifecycle). 3212 backend tests still pass, ruff is clean. * fix(sandbox): address Copilot review on #2881 Three follow-ups from Copilot's review of the LocalSandboxProvider refactor: 1. Synchronisation: ``acquire`` / ``get`` / ``reset`` mutated the cache without any lock, so concurrent acquire of the same ``thread_id`` could create two ``LocalSandbox`` instances and lose one's ``_agent_written_paths`` state. Add a provider-wide ``threading.Lock`` (matching ``AioSandboxProvider``) and build per-thread mappings outside the lock to avoid holding it during the ``ensure_thread_dirs`` filesystem touch. 2. Memory bound: ``_thread_sandboxes`` grew monotonically. Replace the plain dict with an ``OrderedDict`` LRU capped at ``DEFAULT_MAX_CACHED_THREAD_SANDBOXES`` (256, configurable per provider instance). ``get`` promotes touched threads to the MRU end so an active thread isn't evicted under load. Eviction is graceful: the next ``acquire`` rebuilds a fresh sandbox; only ``_agent_written_paths`` (reverse-resolve hint) is lost. 3. Docs: update ``CLAUDE.md`` to reflect the new per-thread architecture, the LRU cap, and that ``is_local_sandbox`` recognises both id formats. New regression tests: - Concurrent ``acquire("alpha")`` from 8 threads yields a single instance (slow-init injection forces the race window wide open). - Concurrent ``acquire`` of distinct thread_ids yields distinct instances. - The cache evicts the least-recently-used thread once the cap is exceeded. - ``get`` promotes recency so a polled thread survives a later acquire-storm. --- backend/CLAUDE.md | 6 +- .../sandbox/local/local_sandbox_provider.py | 235 ++++++++++- .../harness/deerflow/sandbox/tools.py | 10 +- ...est_local_sandbox_virtual_path_contract.py | 366 ++++++++++++++++++ 4 files changed, 592 insertions(+), 25 deletions(-) create mode 100644 backend/tests/test_local_sandbox_virtual_path_contract.py diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 5e0aebfdb..35607c6fd 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -232,14 +232,14 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti **Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir` **Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle **Implementations**: -- `LocalSandboxProvider` - Singleton local filesystem execution with path mappings +- `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`. - `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation **Virtual Path System**: - Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills` - Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/` -- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()` -- Detection: `is_local_sandbox()` checks `sandbox_id == "local"` +- Translation: `LocalSandboxProvider` builds per-thread `PathMapping`s for the user-data prefixes at acquire time; `tools.py` keeps `replace_virtual_path()` / `replace_virtual_paths_in_command()` as a defense-in-depth layer (and for path validation). AIO has the directories volume-mounted at the same virtual paths inside its container, so both implementations accept `/mnt/user-data/...` natively. +- Detection: `is_local_sandbox()` accepts both `sandbox_id == "local"` (legacy / no-thread) and `sandbox_id.startswith("local:")` (per-thread) **Sandbox Tools** (in `packages/harness/deerflow/sandbox/tools.py`): - `bash` - Execute commands with path translation and error handling diff --git a/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py b/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py index 0510a2473..d64a1c220 100644 --- a/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py +++ b/backend/packages/harness/deerflow/sandbox/local/local_sandbox_provider.py @@ -1,4 +1,6 @@ import logging +import threading +from collections import OrderedDict from pathlib import Path from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping @@ -7,25 +9,87 @@ from deerflow.sandbox.sandbox_provider import SandboxProvider logger = logging.getLogger(__name__) +# Module-level alias kept for backward compatibility with older callers/tests +# that reach into ``local_sandbox_provider._singleton`` directly. New code reads +# the provider instance attributes (``_generic_sandbox`` / ``_thread_sandboxes``) +# instead. _singleton: LocalSandbox | None = None +# Virtual prefixes that must be reserved by the per-thread mappings created in +# ``acquire`` — custom mounts from ``config.yaml`` may not overlap with these. +_USER_DATA_VIRTUAL_PREFIX = "/mnt/user-data" +_ACP_WORKSPACE_VIRTUAL_PREFIX = "/mnt/acp-workspace" + +# Default upper bound on per-thread LocalSandbox instances retained in memory. +# Each cached instance is cheap (a small Python object with a list of +# PathMapping and a set of agent-written paths used for reverse resolve), but +# in a long-running gateway the number of distinct thread_ids is unbounded. +# When the cap is exceeded the least-recently-used entry is dropped; the next +# ``acquire(thread_id)`` for that thread simply rebuilds the sandbox at the +# cost of losing its accumulated ``_agent_written_paths`` (read_file falls +# back to no reverse resolution, which is the same behaviour as a fresh run). +DEFAULT_MAX_CACHED_THREAD_SANDBOXES = 256 + class LocalSandboxProvider(SandboxProvider): + """Local-filesystem sandbox provider with per-thread path scoping. + + Earlier revisions of this provider returned a single process-wide + ``LocalSandbox`` keyed by the literal id ``"local"``. That singleton could + not honour the documented ``/mnt/user-data/...`` contract at the public + ``Sandbox`` API boundary because the corresponding host directory is + per-thread (``{base_dir}/users/{user_id}/threads/{thread_id}/user-data/``). + + The provider now produces a fresh ``LocalSandbox`` per ``thread_id`` whose + ``path_mappings`` include thread-scoped entries for + ``/mnt/user-data/{workspace,uploads,outputs}`` and ``/mnt/acp-workspace``, + mirroring how :class:`AioSandboxProvider` bind-mounts those paths into its + docker container. The legacy ``acquire()`` / ``acquire(None)`` call still + returns a generic singleton with id ``"local"`` for callers (and tests) + that do not have a thread context. + + Thread-safety: ``acquire``, ``get`` and ``reset`` may be invoked from + multiple threads (Gateway tool dispatch, subagent worker pools, the + background memory updater, …) so all cache state changes are serialised + through a provider-wide :class:`threading.Lock`. This matches the pattern + used by :class:`AioSandboxProvider`. + + Memory bound: ``_thread_sandboxes`` is an LRU cache capped at + ``max_cached_threads`` (default :data:`DEFAULT_MAX_CACHED_THREAD_SANDBOXES`). + When the cap is exceeded the least-recently-used entry is evicted on the + next ``acquire``; the evicted thread's next ``acquire`` rebuilds a fresh + sandbox (losing only its ``_agent_written_paths`` reverse-resolve hint, + which gracefully degrades read_file output). + """ + uses_thread_data_mounts = True - def __init__(self): - """Initialize the local sandbox provider with path mappings.""" + def __init__(self, max_cached_threads: int = DEFAULT_MAX_CACHED_THREAD_SANDBOXES): + """Initialize the local sandbox provider with static path mappings. + + Args: + max_cached_threads: Upper bound on per-thread sandboxes retained in + the LRU cache. When exceeded, the least-recently-used entry is + evicted on the next ``acquire``. + """ self._path_mappings = self._setup_path_mappings() + self._generic_sandbox: LocalSandbox | None = None + self._thread_sandboxes: OrderedDict[str, LocalSandbox] = OrderedDict() + self._max_cached_threads = max_cached_threads + self._lock = threading.Lock() def _setup_path_mappings(self) -> list[PathMapping]: """ - Setup path mappings for local sandbox. + Setup static path mappings shared by every sandbox this provider yields. - Maps container paths to actual local paths, including skills directory - and any custom mounts configured in config.yaml. + Static mappings cover the skills directory and any custom mounts from + ``config.yaml`` — both are process-wide and identical for every thread. + Per-thread ``/mnt/user-data/...`` and ``/mnt/acp-workspace`` mappings + are appended inside :meth:`acquire` because they depend on + ``thread_id`` and the effective ``user_id``. Returns: - List of path mappings + List of static path mappings """ mappings: list[PathMapping] = [] @@ -48,7 +112,11 @@ class LocalSandboxProvider(SandboxProvider): ) # Map custom mounts from sandbox config - _RESERVED_CONTAINER_PREFIXES = [container_path, "/mnt/acp-workspace", "/mnt/user-data"] + _RESERVED_CONTAINER_PREFIXES = [ + container_path, + _ACP_WORKSPACE_VIRTUAL_PREFIX, + _USER_DATA_VIRTUAL_PREFIX, + ] sandbox_config = config.sandbox if sandbox_config and sandbox_config.mounts: for mount in sandbox_config.mounts: @@ -99,33 +167,162 @@ class LocalSandboxProvider(SandboxProvider): return mappings + @staticmethod + def _build_thread_path_mappings(thread_id: str) -> list[PathMapping]: + """Build per-thread path mappings for /mnt/user-data and /mnt/acp-workspace. + + Resolves ``user_id`` via :func:`get_effective_user_id` (the same path + :class:`AioSandboxProvider` uses) and ensures the backing host + directories exist before they are mapped into the sandbox view. + """ + from deerflow.config.paths import get_paths + from deerflow.runtime.user_context import get_effective_user_id + + paths = get_paths() + user_id = get_effective_user_id() + paths.ensure_thread_dirs(thread_id, user_id=user_id) + + return [ + # Aggregate parent mapping so ``ls /mnt/user-data`` and other + # parent-level operations behave the same as inside AIO (where the + # parent directory is real and contains the three subdirs). Longer + # subpath mappings below still win for ``/mnt/user-data/workspace/...`` + # because ``_find_path_mapping`` sorts by container_path length. + PathMapping( + container_path=_USER_DATA_VIRTUAL_PREFIX, + local_path=str(paths.sandbox_user_data_dir(thread_id, user_id=user_id)), + read_only=False, + ), + PathMapping( + container_path=f"{_USER_DATA_VIRTUAL_PREFIX}/workspace", + local_path=str(paths.sandbox_work_dir(thread_id, user_id=user_id)), + read_only=False, + ), + PathMapping( + container_path=f"{_USER_DATA_VIRTUAL_PREFIX}/uploads", + local_path=str(paths.sandbox_uploads_dir(thread_id, user_id=user_id)), + read_only=False, + ), + PathMapping( + container_path=f"{_USER_DATA_VIRTUAL_PREFIX}/outputs", + local_path=str(paths.sandbox_outputs_dir(thread_id, user_id=user_id)), + read_only=False, + ), + PathMapping( + container_path=_ACP_WORKSPACE_VIRTUAL_PREFIX, + local_path=str(paths.acp_workspace_dir(thread_id, user_id=user_id)), + read_only=False, + ), + ] + def acquire(self, thread_id: str | None = None) -> str: + """Return a sandbox id scoped to *thread_id* (or the generic singleton). + + - ``thread_id=None`` keeps the legacy singleton with id ``"local"`` for + callers that have no thread context (e.g. legacy tests, scripts). + - ``thread_id="abc"`` yields a per-thread ``LocalSandbox`` with id + ``"local:abc"`` whose ``path_mappings`` resolve ``/mnt/user-data/...`` + to that thread's host directories. + + Thread-safe under concurrent invocation: the cache check + insert is + guarded by ``self._lock`` so two callers racing on the same + ``thread_id`` always observe the same LocalSandbox instance. + """ global _singleton - if _singleton is None: - _singleton = LocalSandbox("local", path_mappings=self._path_mappings) - return _singleton.id + + if thread_id is None: + with self._lock: + if self._generic_sandbox is None: + self._generic_sandbox = LocalSandbox("local", path_mappings=list(self._path_mappings)) + _singleton = self._generic_sandbox + return self._generic_sandbox.id + + # Fast path under lock. + with self._lock: + cached = self._thread_sandboxes.get(thread_id) + if cached is not None: + # Mark as most-recently used so frequently-touched threads + # survive eviction. + self._thread_sandboxes.move_to_end(thread_id) + return cached.id + + # ``_build_thread_path_mappings`` touches the filesystem + # (``ensure_thread_dirs``); release the lock during I/O. + new_mappings = list(self._path_mappings) + self._build_thread_path_mappings(thread_id) + + with self._lock: + # Re-check after the lock-free I/O: another caller may have + # populated the cache while we were computing mappings. + cached = self._thread_sandboxes.get(thread_id) + if cached is None: + cached = LocalSandbox(f"local:{thread_id}", path_mappings=new_mappings) + self._thread_sandboxes[thread_id] = cached + self._evict_until_within_cap_locked() + else: + self._thread_sandboxes.move_to_end(thread_id) + return cached.id + + def _evict_until_within_cap_locked(self) -> None: + """LRU-evict cached thread sandboxes once the cap is exceeded. + + Caller MUST hold ``self._lock``. + """ + while len(self._thread_sandboxes) > self._max_cached_threads: + evicted_thread_id, _ = self._thread_sandboxes.popitem(last=False) + logger.info( + "Evicting LocalSandbox cache entry for thread %s (cap=%d)", + evicted_thread_id, + self._max_cached_threads, + ) def get(self, sandbox_id: str) -> Sandbox | None: if sandbox_id == "local": - if _singleton is None: + with self._lock: + generic = self._generic_sandbox + if generic is None: self.acquire() - return _singleton + with self._lock: + return self._generic_sandbox + return generic + if isinstance(sandbox_id, str) and sandbox_id.startswith("local:"): + thread_id = sandbox_id[len("local:") :] + with self._lock: + cached = self._thread_sandboxes.get(thread_id) + if cached is not None: + # Touching a thread via ``get`` (used by tools.py to look + # up the sandbox once per tool call) promotes it in LRU + # order so an active thread isn't evicted under load. + self._thread_sandboxes.move_to_end(thread_id) + return cached return None def release(self, sandbox_id: str) -> None: - # LocalSandbox uses singleton pattern - no cleanup needed. + # LocalSandbox has no resources to release; keep the cached instance so + # that ``_agent_written_paths`` (used to reverse-resolve agent-authored + # file contents on read) survives between turns. LRU eviction in + # ``acquire`` and explicit ``reset()`` / ``shutdown()`` are the only + # paths that drop cached entries. + # # Note: This method is intentionally not called by SandboxMiddleware # to allow sandbox reuse across multiple turns in a thread. - # For Docker-based providers (e.g., AioSandboxProvider), cleanup - # happens at application shutdown via the shutdown() method. pass def reset(self) -> None: - # reset_sandbox_provider() must also clear the module singleton. + """Drop all cached LocalSandbox instances. + + ``reset_sandbox_provider()`` calls this to ensure config / mount + changes take effect on the next ``acquire()``. We also reset the + module-level ``_singleton`` alias so older callers/tests that reach + into it see a fresh state. + """ global _singleton - _singleton = None + with self._lock: + self._generic_sandbox = None + self._thread_sandboxes.clear() + _singleton = None def shutdown(self) -> None: - # LocalSandboxProvider has no extra resources beyond the shared - # singleton, so shutdown uses the same cleanup path as reset. + # LocalSandboxProvider has no extra resources beyond the cached + # ``LocalSandbox`` instances, so shutdown uses the same cleanup path + # as ``reset``. self.reset() diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index 7c746b1aa..2694e9406 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -1006,8 +1006,9 @@ def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None: def is_local_sandbox(runtime: Runtime | None) -> bool: """Check if the current sandbox is a local sandbox. - Path replacement is only needed for local sandbox since aio sandbox - already has /mnt/user-data mounted in the container. + Accepts both the legacy generic id ``"local"`` (acquire with no thread + context) and the per-thread id format ``"local:{thread_id}"`` produced by + :meth:`LocalSandboxProvider.acquire` once a thread is known. """ if runtime is None: return False @@ -1016,7 +1017,10 @@ def is_local_sandbox(runtime: Runtime | None) -> bool: sandbox_state = runtime.state.get("sandbox") if sandbox_state is None: return False - return sandbox_state.get("sandbox_id") == "local" + sandbox_id = sandbox_state.get("sandbox_id") + if not isinstance(sandbox_id, str): + return False + return sandbox_id == "local" or sandbox_id.startswith("local:") def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox: diff --git a/backend/tests/test_local_sandbox_virtual_path_contract.py b/backend/tests/test_local_sandbox_virtual_path_contract.py new file mode 100644 index 000000000..d9ec0cbdc --- /dev/null +++ b/backend/tests/test_local_sandbox_virtual_path_contract.py @@ -0,0 +1,366 @@ +"""Issue #2873 regression — the public Sandbox API must honor the documented +/mnt/user-data contract uniformly across implementations. + +Today AIO sandbox already accepts /mnt/user-data/... paths directly because the +container has those paths bind-mounted per-thread. LocalSandbox, however, +externalises that translation to ``deerflow.sandbox.tools`` via ``thread_data``, +so any caller that bypasses tools.py (e.g. ``uploads.py`` syncing files into a +remote sandbox via ``sandbox.update_file(virtual_path, ...)``) sees inconsistent +behaviour. + +These tests pin down the **public Sandbox API boundary**: when a caller obtains +a ``LocalSandbox`` from ``LocalSandboxProvider.acquire(thread_id)`` and invokes +its abstract methods with documented virtual paths, those paths must resolve to +the thread's user-data directory automatically — no tools.py / thread_data +shim required. +""" + +from __future__ import annotations + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from deerflow.config.sandbox_config import SandboxConfig +from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider + + +def _build_config(skills_dir: Path) -> SimpleNamespace: + """Minimal app config covering what ``LocalSandboxProvider`` reads at init.""" + return SimpleNamespace( + skills=SimpleNamespace( + container_path="/mnt/skills", + get_skills_path=lambda: skills_dir, + use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage", + ), + sandbox=SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=[]), + ) + + +@pytest.fixture +def isolated_paths(monkeypatch, tmp_path): + """Redirect ``get_paths().base_dir`` to ``tmp_path`` and reset its singleton. + + Without this, per-thread directories would be created under the developer's + real ``.deer-flow/`` tree. + """ + monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + from deerflow.config import paths as paths_module + + monkeypatch.setattr(paths_module, "_paths", None) + yield tmp_path + monkeypatch.setattr(paths_module, "_paths", None) + + +@pytest.fixture +def provider(isolated_paths, tmp_path): + """Provider with a real skills dir and no custom mounts.""" + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + cfg = _build_config(skills_dir) + with patch("deerflow.config.get_app_config", return_value=cfg): + yield LocalSandboxProvider() + + +# ────────────────────────────────────────────────────────────────────────── +# 1. Direct Sandbox API accepts the virtual path contract for ``acquire(tid)`` +# ────────────────────────────────────────────────────────────────────────── + + +def test_acquire_with_thread_id_returns_per_thread_id(provider): + sandbox_id = provider.acquire("alpha") + assert sandbox_id == "local:alpha" + + +def test_acquire_without_thread_id_remains_legacy_local_id(provider): + """Backward-compat: ``acquire()`` with no thread keeps the singleton id.""" + assert provider.acquire() == "local" + assert provider.acquire(None) == "local" + + +def test_write_then_read_via_public_api_with_virtual_path(provider): + sandbox_id = provider.acquire("alpha") + sbx = provider.get(sandbox_id) + assert sbx is not None + + virtual = "/mnt/user-data/workspace/hello.txt" + sbx.write_file(virtual, "hi there") + assert sbx.read_file(virtual) == "hi there" + + +def test_list_dir_via_public_api_with_virtual_path(provider): + sandbox_id = provider.acquire("alpha") + sbx = provider.get(sandbox_id) + sbx.write_file("/mnt/user-data/workspace/foo.txt", "x") + entries = sbx.list_dir("/mnt/user-data/workspace") + # entries should be reverse-resolved back to the virtual prefix + assert any("/mnt/user-data/workspace/foo.txt" in e for e in entries) + + +def test_execute_command_with_virtual_path(provider): + sandbox_id = provider.acquire("alpha") + sbx = provider.get(sandbox_id) + sbx.write_file("/mnt/user-data/uploads/note.txt", "payload") + output = sbx.execute_command("ls /mnt/user-data/uploads") + assert "note.txt" in output + + +def test_glob_with_virtual_path(provider): + sandbox_id = provider.acquire("alpha") + sbx = provider.get(sandbox_id) + sbx.write_file("/mnt/user-data/outputs/report.md", "# r") + matches, _ = sbx.glob("/mnt/user-data/outputs", "*.md") + assert any(m.endswith("/mnt/user-data/outputs/report.md") for m in matches) + + +def test_grep_with_virtual_path(provider): + sandbox_id = provider.acquire("alpha") + sbx = provider.get(sandbox_id) + sbx.write_file("/mnt/user-data/workspace/findme.txt", "needle line\nother line") + matches, _ = sbx.grep("/mnt/user-data/workspace", "needle", literal=True) + assert matches + assert matches[0].path.endswith("/mnt/user-data/workspace/findme.txt") + + +def test_execute_command_lists_aggregate_user_data_root(provider): + """``ls /mnt/user-data`` (the parent prefix itself) must list the three + subdirs — matching the AIO container's natural filesystem view.""" + sandbox_id = provider.acquire("alpha") + sbx = provider.get(sandbox_id) + # Touch all three subdirs so they materialise on disk + sbx.write_file("/mnt/user-data/workspace/.keep", "") + sbx.write_file("/mnt/user-data/uploads/.keep", "") + sbx.write_file("/mnt/user-data/outputs/.keep", "") + output = sbx.execute_command("ls /mnt/user-data") + assert "workspace" in output + assert "uploads" in output + assert "outputs" in output + + +def test_update_file_with_virtual_path_for_remote_sync_scenario(provider): + """This is the exact code path used by ``uploads.py:282`` and ``feishu.py:389``. + + They build a ``virtual_path`` like ``/mnt/user-data/uploads/foo.pdf`` and hand + raw bytes to the sandbox. Before this fix LocalSandbox would try to write to + the literal host path ``/mnt/user-data/uploads/foo.pdf`` and fail. + """ + sandbox_id = provider.acquire("alpha") + sbx = provider.get(sandbox_id) + sbx.update_file("/mnt/user-data/uploads/blob.bin", b"\x00\x01\x02binary") + assert sbx.read_file("/mnt/user-data/uploads/blob.bin").startswith("\x00\x01\x02") + + +# ────────────────────────────────────────────────────────────────────────── +# 2. Per-thread isolation (no cross-thread state leaks) +# ────────────────────────────────────────────────────────────────────────── + + +def test_two_threads_get_distinct_sandboxes(provider): + sid_a = provider.acquire("alpha") + sid_b = provider.acquire("beta") + assert sid_a != sid_b + + sbx_a = provider.get(sid_a) + sbx_b = provider.get(sid_b) + assert sbx_a is not sbx_b + + +def test_per_thread_user_data_mapping_isolated(provider, isolated_paths): + """Files written via one thread's sandbox must not be visible through another.""" + sid_a = provider.acquire("alpha") + sid_b = provider.acquire("beta") + sbx_a = provider.get(sid_a) + sbx_b = provider.get(sid_b) + + sbx_a.write_file("/mnt/user-data/workspace/secret.txt", "alpha-only") + # The same virtual path resolves to a different host path in thread "beta" + with pytest.raises(FileNotFoundError): + sbx_b.read_file("/mnt/user-data/workspace/secret.txt") + + +def test_agent_written_paths_per_thread_isolation(provider): + """``_agent_written_paths`` tracks files this sandbox wrote so reverse-resolve + runs on read. The set must not leak across threads.""" + sid_a = provider.acquire("alpha") + sid_b = provider.acquire("beta") + sbx_a = provider.get(sid_a) + sbx_b = provider.get(sid_b) + sbx_a.write_file("/mnt/user-data/workspace/in-a.txt", "marker") + assert sbx_a._agent_written_paths + assert not sbx_b._agent_written_paths + + +# ────────────────────────────────────────────────────────────────────────── +# 3. Lifecycle: get / release / reset +# ────────────────────────────────────────────────────────────────────────── + + +def test_get_returns_cached_instance_for_known_id(provider): + sid = provider.acquire("alpha") + assert provider.get(sid) is provider.get(sid) + + +def test_get_unknown_id_returns_none(provider): + assert provider.get("local:nonexistent") is None + + +def test_release_is_noop_keeps_instance_available(provider): + """Local has no resources to release; the cached instance stays alive across + turns so ``_agent_written_paths`` persists for reverse-resolve on later reads.""" + sid = provider.acquire("alpha") + sbx_before = provider.get(sid) + provider.release(sid) + sbx_after = provider.get(sid) + assert sbx_before is sbx_after + + +def test_reset_clears_both_generic_and_per_thread_caches(provider): + provider.acquire() # populate generic + provider.acquire("alpha") # populate per-thread + assert provider._generic_sandbox is not None + assert provider._thread_sandboxes + + provider.reset() + assert provider._generic_sandbox is None + assert not provider._thread_sandboxes + + +# ────────────────────────────────────────────────────────────────────────── +# 4. is_local_sandbox detects both legacy and per-thread ids +# ────────────────────────────────────────────────────────────────────────── + + +def test_is_local_sandbox_accepts_both_id_formats(): + from deerflow.sandbox.tools import is_local_sandbox + + legacy = SimpleNamespace(state={"sandbox": {"sandbox_id": "local"}}, context={}) + per_thread = SimpleNamespace(state={"sandbox": {"sandbox_id": "local:alpha"}}, context={}) + foreign = SimpleNamespace(state={"sandbox": {"sandbox_id": "aio-12345"}}, context={}) + unset = SimpleNamespace(state={}, context={}) + + assert is_local_sandbox(legacy) is True + assert is_local_sandbox(per_thread) is True + assert is_local_sandbox(foreign) is False + assert is_local_sandbox(unset) is False + + +# ────────────────────────────────────────────────────────────────────────── +# 5. Concurrency safety (Copilot review feedback) +# ────────────────────────────────────────────────────────────────────────── + + +def test_concurrent_acquire_same_thread_yields_single_instance(provider): + """Two threads racing on ``acquire("alpha")`` must share one LocalSandbox. + + Without the provider lock the check-then-act in ``acquire`` is non-atomic: + both racers would see an empty cache, both would build their own + LocalSandbox, and one would overwrite the other — losing the loser's + ``_agent_written_paths`` and any in-flight state on it. + """ + import threading + import time + + from deerflow.sandbox.local import local_sandbox as local_sandbox_module + + # Force a wide race window by slowing the LocalSandbox constructor down. + original_init = local_sandbox_module.LocalSandbox.__init__ + + def slow_init(self, *args, **kwargs): + time.sleep(0.05) + original_init(self, *args, **kwargs) + + barrier = threading.Barrier(8) + results: list[str] = [] + results_lock = threading.Lock() + + def racer(): + barrier.wait() + sid = provider.acquire("alpha") + with results_lock: + results.append(sid) + + with patch.object(local_sandbox_module.LocalSandbox, "__init__", slow_init): + threads = [threading.Thread(target=racer) for _ in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every racer must observe the same ``sandbox_id``… + assert len(set(results)) == 1, f"Racers saw different ids: {results}" + # …and the cache must hold exactly one instance for ``alpha``. + assert len(provider._thread_sandboxes) == 1 + assert "alpha" in provider._thread_sandboxes + + +def test_concurrent_acquire_distinct_threads_yields_distinct_instances(provider): + """Different thread_ids race-acquired in parallel each get their own sandbox.""" + import threading + + barrier = threading.Barrier(6) + sids: dict[str, str] = {} + lock = threading.Lock() + + def racer(name: str): + barrier.wait() + sid = provider.acquire(name) + with lock: + sids[name] = sid + + threads = [threading.Thread(target=racer, args=(f"t{i}",)) for i in range(6)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert set(sids.values()) == {f"local:t{i}" for i in range(6)} + assert set(provider._thread_sandboxes.keys()) == {f"t{i}" for i in range(6)} + + +# ────────────────────────────────────────────────────────────────────────── +# 6. Bounded memory growth (Copilot review feedback) +# ────────────────────────────────────────────────────────────────────────── + + +def test_thread_sandbox_cache_is_bounded(isolated_paths, tmp_path): + """The LRU cap must evict the least-recently-used thread sandboxes once + exceeded — otherwise long-running gateways would accumulate cache entries + for every distinct ``thread_id`` ever served.""" + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + cfg = _build_config(skills_dir) + + with patch("deerflow.config.get_app_config", return_value=cfg): + provider = LocalSandboxProvider(max_cached_threads=3) + + for i in range(5): + provider.acquire(f"t{i}") + + # Only the 3 most-recent thread_ids should be retained. + assert set(provider._thread_sandboxes.keys()) == {"t2", "t3", "t4"} + assert provider.get("local:t0") is None + assert provider.get("local:t4") is not None + + +def test_lru_promotes_recently_used_thread(isolated_paths, tmp_path): + """``get`` on a cached thread should mark it as most-recently used so a + later acquire-storm doesn't evict an active thread that is being polled.""" + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + cfg = _build_config(skills_dir) + + with patch("deerflow.config.get_app_config", return_value=cfg): + provider = LocalSandboxProvider(max_cached_threads=3) + + for name in ["a", "b", "c"]: + provider.acquire(name) + # Touch "a" via ``get`` so it becomes most-recently used. + provider.get("local:a") + # Adding a fourth thread should evict "b" (the new LRU), not "a". + provider.acquire("d") + + assert "a" in provider._thread_sandboxes + assert "b" not in provider._thread_sandboxes + assert {"a", "c", "d"} == set(provider._thread_sandboxes.keys()) From a814ab50b5721b4fbad95b005f59043f880ef18e Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Sun, 17 May 2026 08:59:42 +0800 Subject: [PATCH 34/86] fix(skills): make security scanner JSON parsing robust for LLM output variations (#2987) The moderation model's response was silently falling through to a conservative block when LLMs wrapped structured output in markdown code fences, added prose around the JSON, returned case-variant decisions (e.g. "Allow"), or included nested braces in the reason field. The greedy `\{.*\}` regex also over-matched on nested braces. - Rewrite _extract_json_object() with markdown fence stripping and brace-balanced string-aware extraction - Normalize decision field to lowercase for case-insensitive matching - Distinguish "model unavailable" from "unparseable output" in fallback - Strengthen system prompt to explicitly forbid code fences and prose - Add 15 tests covering all reported scenarios Fixes #2985 --- .../deerflow/skills/security_scanner.py | 59 +++++++-- backend/tests/test_security_scanner.py | 117 ++++++++++++++++-- 2 files changed, 159 insertions(+), 17 deletions(-) diff --git a/backend/packages/harness/deerflow/skills/security_scanner.py b/backend/packages/harness/deerflow/skills/security_scanner.py index 3bddb018f..a9c7b0279 100644 --- a/backend/packages/harness/deerflow/skills/security_scanner.py +++ b/backend/packages/harness/deerflow/skills/security_scanner.py @@ -23,19 +23,49 @@ class ScanResult: def _extract_json_object(raw: str) -> dict | None: raw = raw.strip() + + # Strip markdown code fences (```json ... ``` or ``` ... ```) + fence_match = re.match(r"^```(?:json)?\s*\n?(.*?)\n?\s*```$", raw, re.DOTALL) + if fence_match: + raw = fence_match.group(1).strip() + try: return json.loads(raw) except json.JSONDecodeError: pass - match = re.search(r"\{.*\}", raw, re.DOTALL) - if not match: - return None - try: - return json.loads(match.group(0)) - except json.JSONDecodeError: + # Brace-balanced extraction with string-awareness + start = raw.find("{") + if start == -1: return None + depth = 0 + in_string = False + escape = False + for i in range(start, len(raw)): + c = raw[i] + if escape: + escape = False + continue + if c == "\\": + escape = True + continue + if c == '"': + in_string = not in_string + continue + if in_string: + continue + if c == "{": + depth += 1 + elif c == "}": + depth -= 1 + if depth == 0: + try: + return json.loads(raw[start : i + 1]) + except json.JSONDecodeError: + return None + return None + async def scan_skill_content(content: str, *, executable: bool = False, location: str = SKILL_MD_FILE, app_config: AppConfig | None = None) -> ScanResult: """Screen skill content before it is written to disk.""" @@ -44,10 +74,12 @@ async def scan_skill_content(content: str, *, executable: bool = False, location "Classify the content as allow, warn, or block. " "Block clear prompt-injection, system-role override, privilege escalation, exfiltration, " "or unsafe executable code. Warn for borderline external API references. " - 'Return strict JSON: {"decision":"allow|warn|block","reason":"..."}.' + "Respond with ONLY a single JSON object on one line, no code fences, no commentary:\n" + '{"decision":"allow|warn|block","reason":"..."}' ) prompt = f"Location: {location}\nExecutable: {str(executable).lower()}\n\nReview this content:\n-----\n{content}\n-----" + model_responded = False try: config = app_config or get_app_config() model_name = config.skill_evolution.moderation_model_name @@ -59,12 +91,19 @@ async def scan_skill_content(content: str, *, executable: bool = False, location ], config={"run_name": "security_agent"}, ) - parsed = _extract_json_object(str(getattr(response, "content", "") or "")) - if parsed and parsed.get("decision") in {"allow", "warn", "block"}: - return ScanResult(parsed["decision"], str(parsed.get("reason") or "No reason provided.")) + model_responded = True + raw = str(getattr(response, "content", "") or "") + parsed = _extract_json_object(raw) + if parsed: + decision = str(parsed.get("decision", "")).lower() + if decision in {"allow", "warn", "block"}: + return ScanResult(decision, str(parsed.get("reason") or "No reason provided.")) + logger.warning("Security scan produced unparseable output: %s", raw[:200]) except Exception: logger.warning("Skill security scan model call failed; using conservative fallback", exc_info=True) + if model_responded: + return ScanResult("block", "Security scan produced unparseable output; manual review required.") if executable: return ScanResult("block", "Security scan unavailable for executable content; manual review required.") return ScanResult("block", "Security scan unavailable for skill content; manual review required.") diff --git a/backend/tests/test_security_scanner.py b/backend/tests/test_security_scanner.py index 088cb2c11..61277efd8 100644 --- a/backend/tests/test_security_scanner.py +++ b/backend/tests/test_security_scanner.py @@ -2,13 +2,12 @@ from types import SimpleNamespace import pytest -from deerflow.skills.security_scanner import scan_skill_content +from deerflow.skills.security_scanner import _extract_json_object, scan_skill_content -@pytest.mark.anyio -async def test_scan_skill_content_passes_run_name_to_model(monkeypatch): +def _make_env(monkeypatch, response_content): config = SimpleNamespace(skill_evolution=SimpleNamespace(moderation_model_name=None)) - fake_response = SimpleNamespace(content='{"decision":"allow","reason":"ok"}') + fake_response = SimpleNamespace(content=response_content) class FakeModel: async def ainvoke(self, *args, **kwargs): @@ -19,9 +18,59 @@ async def test_scan_skill_content_passes_run_name_to_model(monkeypatch): model = FakeModel() monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config) monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: model) + return model - result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False) +SKILL_CONTENT = "---\nname: demo-skill\ndescription: demo\n---\n" + + +# --- _extract_json_object unit tests --- + + +def test_extract_json_plain(): + assert _extract_json_object('{"decision":"allow","reason":"ok"}') == {"decision": "allow", "reason": "ok"} + + +def test_extract_json_markdown_fence(): + raw = '```json\n{"decision": "allow", "reason": "ok"}\n```' + assert _extract_json_object(raw) == {"decision": "allow", "reason": "ok"} + + +def test_extract_json_fence_no_language(): + raw = '```\n{"decision": "allow", "reason": "ok"}\n```' + assert _extract_json_object(raw) == {"decision": "allow", "reason": "ok"} + + +def test_extract_json_prose_wrapped(): + raw = 'Looking at this content I conclude: {"decision": "allow", "reason": "clean"} and that is final.' + assert _extract_json_object(raw) == {"decision": "allow", "reason": "clean"} + + +def test_extract_json_nested_braces_in_reason(): + raw = '{"decision": "allow", "reason": "no issues with {placeholder} found"}' + assert _extract_json_object(raw) == {"decision": "allow", "reason": "no issues with {placeholder} found"} + + +def test_extract_json_nested_braces_code_snippet(): + raw = 'Here is my review: {"decision": "block", "reason": "contains {\\"x\\": 1} code injection"}' + assert _extract_json_object(raw) == {"decision": "block", "reason": 'contains {"x": 1} code injection'} + + +def test_extract_json_returns_none_for_garbage(): + assert _extract_json_object("no json here") is None + + +def test_extract_json_returns_none_for_unclosed_brace(): + assert _extract_json_object('{"decision": "allow"') is None + + +# --- scan_skill_content integration tests --- + + +@pytest.mark.anyio +async def test_scan_skill_content_passes_run_name_to_model(monkeypatch): + model = _make_env(monkeypatch, '{"decision":"allow","reason":"ok"}') + result = await scan_skill_content(SKILL_CONTENT, executable=False) assert result.decision == "allow" assert model.kwargs["config"] == {"run_name": "security_agent"} @@ -32,7 +81,61 @@ async def test_scan_skill_content_blocks_when_model_unavailable(monkeypatch): monkeypatch.setattr("deerflow.skills.security_scanner.get_app_config", lambda: config) monkeypatch.setattr("deerflow.skills.security_scanner.create_chat_model", lambda **kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) - result = await scan_skill_content("---\nname: demo-skill\ndescription: demo\n---\n", executable=False) + result = await scan_skill_content(SKILL_CONTENT, executable=False) assert result.decision == "block" - assert "manual review required" in result.reason + assert "unavailable" in result.reason + + +@pytest.mark.anyio +async def test_scan_allows_markdown_fenced_response(monkeypatch): + _make_env(monkeypatch, '```json\n{"decision": "allow", "reason": "clean"}\n```') + result = await scan_skill_content(SKILL_CONTENT, executable=False) + assert result.decision == "allow" + assert result.reason == "clean" + + +@pytest.mark.anyio +async def test_scan_normalizes_decision_case(monkeypatch): + _make_env(monkeypatch, '{"decision": "Allow", "reason": "looks fine"}') + result = await scan_skill_content(SKILL_CONTENT, executable=False) + assert result.decision == "allow" + + +@pytest.mark.anyio +async def test_scan_normalizes_uppercase_decision(monkeypatch): + _make_env(monkeypatch, '{"decision": "BLOCK", "reason": "dangerous"}') + result = await scan_skill_content(SKILL_CONTENT, executable=False) + assert result.decision == "block" + + +@pytest.mark.anyio +async def test_scan_handles_nested_braces_in_reason(monkeypatch): + _make_env(monkeypatch, '{"decision": "allow", "reason": "no issues with {placeholder}"}') + result = await scan_skill_content(SKILL_CONTENT, executable=False) + assert result.decision == "allow" + assert "{placeholder}" in result.reason + + +@pytest.mark.anyio +async def test_scan_handles_prose_wrapped_json(monkeypatch): + _make_env(monkeypatch, 'I reviewed the content: {"decision": "allow", "reason": "safe"}\nDone.') + result = await scan_skill_content(SKILL_CONTENT, executable=False) + assert result.decision == "allow" + + +@pytest.mark.anyio +async def test_scan_distinguishes_unparseable_from_unavailable(monkeypatch): + _make_env(monkeypatch, "I can't decide, this is just prose without any JSON at all.") + result = await scan_skill_content(SKILL_CONTENT, executable=False) + assert result.decision == "block" + assert "unparseable" in result.reason + + +@pytest.mark.anyio +async def test_scan_distinguishes_unparseable_executable(monkeypatch): + _make_env(monkeypatch, "no json here") + result = await scan_skill_content(SKILL_CONTENT, executable=True) + # Even for executable content, unparseable uses the unparseable message + assert result.decision == "block" + assert "unparseable" in result.reason From c0233cae268a7deaca559354ef45e1eaf327f197 Mon Sep 17 00:00:00 2001 From: jinghuan-Chen <42742857+jinghuan-Chen@users.noreply.github.com> Date: Sun, 17 May 2026 09:01:42 +0800 Subject: [PATCH 35/86] fix(frontend): resolve login page flickering and resize observer loop. (#2954) * fix(frontend): resolve login page flickering and resize observer loop. * fix(frontend): allow vertical scrolling on login page Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- frontend/src/app/(auth)/login/page.tsx | 2 +- frontend/src/components/ui/flickering-grid.tsx | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/src/app/(auth)/login/page.tsx b/frontend/src/app/(auth)/login/page.tsx index 82fcf8b90..6fb48a572 100644 --- a/frontend/src/app/(auth)/login/page.tsx +++ b/frontend/src/app/(auth)/login/page.tsx @@ -130,7 +130,7 @@ export default function LoginPage() { const actualTheme = theme === "system" ? resolvedTheme : theme; return ( -
+
= ({ return (
Date: Sun, 17 May 2026 15:23:42 +0800 Subject: [PATCH 36/86] fix(sandbox): scope provisioner PVC data by user (#2973) * fix(sandbox): scope provisioner PVC data by user * Address provisioner PVC review feedback --- .../community/aio_sandbox/remote_backend.py | 3 ++ backend/tests/test_aio_sandbox_provider.py | 35 +++++++++++++++++++ backend/tests/test_provisioner_pvc_volumes.py | 22 +++++++----- backend/tests/test_remote_sandbox_backend.py | 6 +++- docker/docker-compose-dev.yaml | 2 +- docker/provisioner/README.md | 25 ++++++++++--- docker/provisioner/app.py | 28 +++++++-------- 7 files changed, 92 insertions(+), 29 deletions(-) diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py index 4f64070d2..9b23e05dc 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py @@ -21,6 +21,8 @@ import logging import requests +from deerflow.runtime.user_context import get_effective_user_id + from .backend import SandboxBackend from .sandbox_info import SandboxInfo @@ -138,6 +140,7 @@ class RemoteSandboxBackend(SandboxBackend): json={ "sandbox_id": sandbox_id, "thread_id": thread_id, + "user_id": get_effective_user_id(), }, timeout=30, ) diff --git a/backend/tests/test_aio_sandbox_provider.py b/backend/tests/test_aio_sandbox_provider.py index c7984531f..732d52170 100644 --- a/backend/tests/test_aio_sandbox_provider.py +++ b/backend/tests/test_aio_sandbox_provider.py @@ -1,11 +1,13 @@ """Tests for AioSandboxProvider mount helpers.""" import importlib +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from deerflow.config.paths import Paths, join_host_path +from deerflow.runtime.user_context import reset_current_user, set_current_user # ── ensure_thread_dirs ─────────────────────────────────────────────────────── @@ -136,3 +138,36 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc provider._discover_or_create_with_lock("thread-5", "sandbox-5") assert unlock_calls == [] + + +def test_remote_backend_create_forwards_effective_user_id(monkeypatch): + """Provisioner mode must receive user_id so PVC subPath matches user isolation.""" + remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend") + backend = remote_mod.RemoteSandboxBackend("http://provisioner:8002") + token = set_current_user(SimpleNamespace(id="user-7")) + posted: dict = {} + + class _Response: + def raise_for_status(self): + return None + + def json(self): + return {"sandbox_url": "http://sandbox.local"} + + def _post(url, json, timeout): # noqa: A002 - mirrors requests.post kwarg + posted.update({"url": url, "json": json, "timeout": timeout}) + return _Response() + + monkeypatch.setattr(remote_mod.requests, "post", _post) + + try: + backend.create("thread-42", "sandbox-42") + finally: + reset_current_user(token) + + assert posted["url"] == "http://provisioner:8002/api/sandboxes" + assert posted["json"] == { + "sandbox_id": "sandbox-42", + "thread_id": "thread-42", + "user_id": "user-7", + } diff --git a/backend/tests/test_provisioner_pvc_volumes.py b/backend/tests/test_provisioner_pvc_volumes.py index 5566f63bd..d5b66a2c7 100644 --- a/backend/tests/test_provisioner_pvc_volumes.py +++ b/backend/tests/test_provisioner_pvc_volumes.py @@ -92,12 +92,19 @@ class TestBuildVolumeMounts: userdata_mount = mounts[1] assert userdata_mount.sub_path is None - def test_pvc_sets_subpath(self, provisioner_module): - """PVC mode should set sub_path to threads/{thread_id}/user-data.""" + def test_pvc_sets_user_scoped_subpath(self, provisioner_module): + """PVC mode should include user_id in the user-data subPath.""" + provisioner_module.USERDATA_PVC_NAME = "my-pvc" + mounts = provisioner_module._build_volume_mounts("thread-42", user_id="user-7") + userdata_mount = mounts[1] + assert userdata_mount.sub_path == "deer-flow/users/user-7/threads/thread-42/user-data" + + def test_pvc_defaults_to_default_user_subpath(self, provisioner_module): + """Older callers should still land under a stable default user namespace.""" provisioner_module.USERDATA_PVC_NAME = "my-pvc" mounts = provisioner_module._build_volume_mounts("thread-42") userdata_mount = mounts[1] - assert userdata_mount.sub_path == "threads/thread-42/user-data" + assert userdata_mount.sub_path == "deer-flow/users/default/threads/thread-42/user-data" def test_skills_mount_read_only(self, provisioner_module): """Skills mount should always be read-only.""" @@ -146,13 +153,12 @@ class TestBuildPodVolumes: pod = provisioner_module._build_pod("sandbox-1", "thread-1") assert len(pod.spec.containers[0].volume_mounts) == 2 - def test_pod_pvc_mode(self, provisioner_module): - """Pod should use PVC volumes when PVC names are configured.""" + def test_pod_pvc_mode_uses_user_scoped_subpath(self, provisioner_module): + """Pod should use a user-scoped subPath for PVC user-data.""" provisioner_module.SKILLS_PVC_NAME = "skills-pvc" provisioner_module.USERDATA_PVC_NAME = "userdata-pvc" - pod = provisioner_module._build_pod("sandbox-1", "thread-1") + pod = provisioner_module._build_pod("sandbox-1", "thread-1", user_id="user-7") assert pod.spec.volumes[0].persistent_volume_claim is not None assert pod.spec.volumes[1].persistent_volume_claim is not None - # subPath should be set on user-data mount userdata_mount = pod.spec.containers[0].volume_mounts[1] - assert userdata_mount.sub_path == "threads/thread-1/user-data" + assert userdata_mount.sub_path == "deer-flow/users/user-7/threads/thread-1/user-data" diff --git a/backend/tests/test_remote_sandbox_backend.py b/backend/tests/test_remote_sandbox_backend.py index c33cd66ef..ed4dd7991 100644 --- a/backend/tests/test_remote_sandbox_backend.py +++ b/backend/tests/test_remote_sandbox_backend.py @@ -144,7 +144,11 @@ def test_provisioner_create_returns_sandbox_info(monkeypatch): def mock_post(url: str, json: dict, timeout: int): assert url == "http://provisioner:8002/api/sandboxes" - assert json == {"sandbox_id": "abc123", "thread_id": "thread-1"} + assert json == { + "sandbox_id": "abc123", + "thread_id": "thread-1", + "user_id": "test-user-autouse", + } assert timeout == 30 return _StubResponse(payload={"sandbox_id": "abc123", "sandbox_url": "http://k3s:31001"}) diff --git a/docker/docker-compose-dev.yaml b/docker/docker-compose-dev.yaml index db608f597..b2e15680f 100644 --- a/docker/docker-compose-dev.yaml +++ b/docker/docker-compose-dev.yaml @@ -37,7 +37,7 @@ services: - THREADS_HOST_PATH=${DEER_FLOW_ROOT}/backend/.deer-flow/threads # Production: use PVC instead of hostPath to avoid data loss on node failure. # When set, hostPath vars above are ignored for the corresponding volume. - # USERDATA_PVC_NAME uses subPath (threads/{thread_id}/user-data) automatically. + # USERDATA_PVC_NAME uses subPath (deer-flow/users/{user_id}/threads/{thread_id}/user-data) automatically. # - SKILLS_PVC_NAME=deer-flow-skills-pvc # - USERDATA_PVC_NAME=deer-flow-userdata-pvc - KUBECONFIG_PATH=/root/.kube/config diff --git a/docker/provisioner/README.md b/docker/provisioner/README.md index 557ad6cfd..36251da17 100644 --- a/docker/provisioner/README.md +++ b/docker/provisioner/README.md @@ -20,7 +20,7 @@ The **Sandbox Provisioner** is a FastAPI service that dynamically manages sandbo ### How It Works -1. **Backend Request**: When the backend needs to execute code, it sends a `POST /api/sandboxes` request with a `sandbox_id` and `thread_id`. +1. **Backend Request**: When the backend needs to execute code, it sends a `POST /api/sandboxes` request with a `sandbox_id`, `thread_id`, and optional `user_id`. 2. **Pod Creation**: The provisioner creates a dedicated Pod in the `deer-flow` namespace with: - The sandbox container image (all-in-one-sandbox) @@ -70,10 +70,13 @@ Create a new sandbox Pod + Service. ```json { "sandbox_id": "abc-123", - "thread_id": "thread-456" + "thread_id": "thread-456", + "user_id": "user-789" } ``` +`user_id` is optional for backwards compatibility and defaults to `default`. When `USERDATA_PVC_NAME` is set, the provisioner uses it to isolate PVC-backed user-data directories. + **Response**: ```json { @@ -138,11 +141,25 @@ The provisioner is configured via environment variables (set in [docker-compose- | `SKILLS_HOST_PATH` | - | **Host machine** path to skills directory (must be absolute) | | `THREADS_HOST_PATH` | - | **Host machine** path to threads data directory (must be absolute) | | `SKILLS_PVC_NAME` | empty (use hostPath) | PVC name for skills volume; when set, sandbox Pods use PVC instead of hostPath | -| `USERDATA_PVC_NAME` | empty (use hostPath) | PVC name for user-data volume; when set, uses PVC with `subPath: threads/{thread_id}/user-data` | +| `USERDATA_PVC_NAME` | empty (use hostPath) | PVC name for user-data volume; when set, uses PVC with `subPath: deer-flow/users/{user_id}/threads/{thread_id}/user-data` | | `KUBECONFIG_PATH` | `/root/.kube/config` | Path to kubeconfig **inside** the provisioner container | | `NODE_HOST` | `host.docker.internal` | Hostname that backend containers use to reach host NodePorts | | `K8S_API_SERVER` | (from kubeconfig) | Override K8s API server URL (e.g., `https://host.docker.internal:26443`) | +### PVC User-Data Upgrade Note + +Older provisioner versions mounted PVC user-data from `threads/{thread_id}/user-data`. The user-scoped layout mounts from `deer-flow/users/{user_id}/threads/{thread_id}/user-data`. + +If an existing deployment already has PVC-backed user-data under the legacy layout, migrate the DeerFlow data directory before relying on the new PVC subPath. Mount the same PVC path that the gateway uses as its DeerFlow base directory, then run the existing user-isolation migration script: + +```bash +cd backend +PYTHONPATH=. python scripts/migrate_user_isolation.py --dry-run +PYTHONPATH=. python scripts/migrate_user_isolation.py --user-id +``` + +This moves legacy `threads/{thread_id}/user-data` data under `users//threads/{thread_id}/user-data`, which matches the new provisioner PVC subPath when the gateway base directory is mounted at `deer-flow/` on the PVC. Use `default` as the target user only when the legacy data should remain in the default no-auth user namespace. Run the migration while no gateway or sandbox Pods are writing to those paths. + ### Important: K8S_API_SERVER Override If your kubeconfig uses `localhost`, `127.0.0.1`, or `0.0.0.0` as the API server address (common with OrbStack, minikube, kind), the provisioner **cannot** reach it from inside the Docker container. @@ -213,7 +230,7 @@ curl http://localhost:8002/health # Create a sandbox (via provisioner container for internal DNS) docker exec deer-flow-provisioner curl -X POST http://localhost:8002/api/sandboxes \ -H "Content-Type: application/json" \ - -d '{"sandbox_id":"test-001","thread_id":"thread-001"}' + -d '{"sandbox_id":"test-001","thread_id":"thread-001","user_id":"user-001"}' # Check sandbox status docker exec deer-flow-provisioner curl http://localhost:8002/api/sandboxes/test-001 diff --git a/docker/provisioner/app.py b/docker/provisioner/app.py index 11e1e424f..91c09f9ee 100644 --- a/docker/provisioner/app.py +++ b/docker/provisioner/app.py @@ -63,6 +63,8 @@ THREADS_HOST_PATH = os.environ.get("THREADS_HOST_PATH", "/.deer-flow/threads") SKILLS_PVC_NAME = os.environ.get("SKILLS_PVC_NAME", "") USERDATA_PVC_NAME = os.environ.get("USERDATA_PVC_NAME", "") SAFE_THREAD_ID_PATTERN = r"^[A-Za-z0-9_\-]+$" +SAFE_USER_ID_PATTERN = r"^[A-Za-z0-9_\-]+$" +DEFAULT_USER_ID = "default" # Path to the kubeconfig *inside* the provisioner container. # Typically the host's ~/.kube/config is mounted here. @@ -95,14 +97,6 @@ def join_host_path(base: str, *parts: str) -> str: return str(result) -def _validate_thread_id(thread_id: str) -> str: - if not re.match(SAFE_THREAD_ID_PATTERN, thread_id): - raise ValueError( - "Invalid thread_id: only alphanumeric characters, hyphens, and underscores are allowed." - ) - return thread_id - - # ── K8s client setup ──────────────────────────────────────────────────── core_v1: k8s_client.CoreV1Api | None = None @@ -221,6 +215,7 @@ app = FastAPI(title="DeerFlow Sandbox Provisioner", lifespan=lifespan) class CreateSandboxRequest(BaseModel): sandbox_id: str thread_id: str = Field(pattern=SAFE_THREAD_ID_PATTERN) + user_id: str = Field(default=DEFAULT_USER_ID, pattern=SAFE_USER_ID_PATTERN) class SandboxResponse(BaseModel): @@ -283,7 +278,7 @@ def _build_volumes(thread_id: str) -> list[k8s_client.V1Volume]: return [skills_vol, userdata_vol] -def _build_volume_mounts(thread_id: str) -> list[k8s_client.V1VolumeMount]: +def _build_volume_mounts(thread_id: str, user_id: str = DEFAULT_USER_ID) -> list[k8s_client.V1VolumeMount]: """Build volume mount list, using subPath for PVC user-data.""" userdata_mount = k8s_client.V1VolumeMount( name="user-data", @@ -291,7 +286,7 @@ def _build_volume_mounts(thread_id: str) -> list[k8s_client.V1VolumeMount]: read_only=False, ) if USERDATA_PVC_NAME: - userdata_mount.sub_path = f"threads/{thread_id}/user-data" + userdata_mount.sub_path = f"deer-flow/users/{user_id}/threads/{thread_id}/user-data" return [ k8s_client.V1VolumeMount( @@ -303,9 +298,8 @@ def _build_volume_mounts(thread_id: str) -> list[k8s_client.V1VolumeMount]: ] -def _build_pod(sandbox_id: str, thread_id: str) -> k8s_client.V1Pod: +def _build_pod(sandbox_id: str, thread_id: str, user_id: str = DEFAULT_USER_ID) -> k8s_client.V1Pod: """Construct a Pod manifest for a single sandbox.""" - thread_id = _validate_thread_id(thread_id) return k8s_client.V1Pod( metadata=k8s_client.V1ObjectMeta( name=_pod_name(sandbox_id), @@ -362,7 +356,7 @@ def _build_pod(sandbox_id: str, thread_id: str) -> k8s_client.V1Pod: "ephemeral-storage": "500Mi", }, ), - volume_mounts=_build_volume_mounts(thread_id), + volume_mounts=_build_volume_mounts(thread_id, user_id=user_id), security_context=k8s_client.V1SecurityContext( privileged=False, allow_privilege_escalation=True, @@ -445,9 +439,13 @@ async def create_sandbox(req: CreateSandboxRequest): """ sandbox_id = req.sandbox_id thread_id = req.thread_id + user_id = req.user_id logger.info( - f"Received request to create sandbox '{sandbox_id}' for thread '{thread_id}'" + "Received request to create sandbox '%s' for thread '%s' user '%s'", + sandbox_id, + thread_id, + user_id, ) # ── Fast path: sandbox already exists ──────────────────────────── @@ -461,7 +459,7 @@ async def create_sandbox(req: CreateSandboxRequest): # ── Create Pod ─────────────────────────────────────────────────── try: - core_v1.create_namespaced_pod(K8S_NAMESPACE, _build_pod(sandbox_id, thread_id)) + core_v1.create_namespaced_pod(K8S_NAMESPACE, _build_pod(sandbox_id, thread_id, user_id=user_id)) logger.info(f"Created Pod {_pod_name(sandbox_id)}") except ApiException as exc: if exc.status != 409: # 409 = AlreadyExists From 39f901d3a5e1aa0b0e5ceafea2a1e092566ac6de Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Sun, 17 May 2026 20:03:21 +0800 Subject: [PATCH 37/86] fix(runs): restore historical runs from persistent store after gateway restart (#2989) * fix(runs): restore historical runs from persistent store after gateway restart RunManager.list_by_thread() and get() only queried the in-memory _runs dict, returning empty results after a restart even when PostgreSQL had the records. Add store fallback to both read paths and a new async aget() for the API endpoint, keeping sync get() for internal callers that need live task/abort_event state. Fixes #2984 * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix(runs): scope run store fallback reads by user id Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/e73daada-1215-4bc1-ab7d-7117826c5013 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * test(runs): clarify ordering expectation and mock store filters Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/e73daada-1215-4bc1-ab7d-7117826c5013 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * test(runs): make user filter fallback assertions explicit Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/e73daada-1215-4bc1-ab7d-7117826c5013 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * test(runs): verify user-isolated fallback behavior with memory store Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/e73daada-1215-4bc1-ab7d-7117826c5013 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * update the code with feedback from issue-2984 --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- backend/app/gateway/routers/thread_runs.py | 6 +- .../harness/deerflow/runtime/runs/manager.py | 57 ++++++- .../deerflow/runtime/runs/store/base.py | 7 +- .../deerflow/runtime/runs/store/memory.py | 9 +- backend/tests/test_run_manager.py | 158 ++++++++++++++++++ 5 files changed, 226 insertions(+), 11 deletions(-) diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index 30365fb7d..3d429fc03 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -180,7 +180,8 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> async def list_runs(thread_id: str, request: Request) -> list[RunResponse]: """List all runs for a thread.""" run_mgr = get_run_manager(request) - records = await run_mgr.list_by_thread(thread_id) + user_id = await get_current_user(request) + records = await run_mgr.list_by_thread(thread_id, user_id=user_id) return [_record_to_response(r) for r in records] @@ -189,7 +190,8 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]: async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: """Get details of a specific run.""" run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) + user_id = await get_current_user(request) + record = await run_mgr.aget(run_id, user_id=user_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") return _record_to_response(record) diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 50dc594ab..11d6b478e 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -111,15 +111,60 @@ class RunManager: return record def get(self, run_id: str) -> RunRecord | None: - """Return a run record by ID, or ``None``.""" + """Return an in-memory run record by ID, or ``None``.""" return self._runs.get(run_id) - async def list_by_thread(self, thread_id: str) -> list[RunRecord]: - """Return all runs for a given thread, newest first.""" + async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: + """Return a run record by ID, checking the persistent store as fallback.""" + record = self._runs.get(run_id) + if record is not None: + return record + if self._store is not None: + try: + d = await self._store.get(run_id, user_id=user_id) + if d is not None: + return self._store_dict_to_record(d) + except Exception: + logger.warning("Failed to query store for run %s", run_id, exc_info=True) + return None + + def _store_dict_to_record(self, d: dict) -> RunRecord: + """Convert a store dict back to a RunRecord for read-only use.""" + return RunRecord( + run_id=d["run_id"], + thread_id=d["thread_id"], + assistant_id=d.get("assistant_id"), + status=RunStatus(d.get("status", RunStatus.error.value)), + on_disconnect=DisconnectMode.cancel, + multitask_strategy=d.get("multitask_strategy", "reject"), + metadata=d.get("metadata", {}), + kwargs=d.get("kwargs", {}), + created_at=d.get("created_at", ""), + updated_at=d.get("updated_at", ""), + model_name=d.get("model_name"), + error=d.get("error"), + ) + + async def list_by_thread(self, thread_id: str, *, user_id: str | None = None) -> list[RunRecord]: + """Return all runs for a given thread, oldest first.""" async with self._lock: - # Dict insertion order matches creation order, so reversing it gives - # us deterministic newest-first results even when timestamps tie. - return [r for r in self._runs.values() if r.thread_id == thread_id] + in_memory = [r for r in self._runs.values() if r.thread_id == thread_id] + in_memory_ids = {r.run_id for r in in_memory} + + store_records: list[RunRecord] = [] + if self._store is not None: + try: + store_dicts = await self._store.list_by_thread(thread_id, user_id=user_id) + for d in store_dicts: + if d["run_id"] not in in_memory_ids: + store_records.append(self._store_dict_to_record(d)) + except Exception: + logger.warning("Failed to query store for thread %s runs", thread_id, exc_info=True) + + return sorted( + in_memory + store_records, + key=lambda record: record.created_at or "", + ) async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: """Transition a run to a new status.""" diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index d3c10eba6..a742d89ca 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -34,7 +34,12 @@ class RunStore(abc.ABC): pass @abc.abstractmethod - async def get(self, run_id: str) -> dict[str, Any] | None: + async def get( + self, + run_id: str, + *, + user_id: str | None = None, + ) -> dict[str, Any] | None: pass @abc.abstractmethod diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index e41147e3e..9db27cacc 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -46,8 +46,13 @@ class MemoryRunStore(RunStore): "updated_at": now, } - async def get(self, run_id): - return self._runs.get(run_id) + async def get(self, run_id, *, user_id=None): + run = self._runs.get(run_id) + if run is None: + return None + if user_id is not None and run.get("user_id") != user_id: + return None + return run async def list_by_thread(self, thread_id, *, user_id=None, limit=100): results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)] diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index 98cd58264..de8f66319 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -83,6 +83,7 @@ async def test_list_by_thread(manager: RunManager): runs = await manager.list_by_thread("thread-1") assert len(runs) == 2 + # list_by_thread returns oldest-first (ascending created_at). assert runs[0].run_id == r1.run_id assert runs[1].run_id == r2.run_id @@ -192,3 +193,160 @@ async def test_model_name_default_is_none(): stored = await store.get(record.run_id) assert stored["model_name"] is None + + +# --------------------------------------------------------------------------- +# Store fallback tests (simulates gateway restart scenario) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def manager_with_store() -> RunManager: + """RunManager backed by a MemoryRunStore.""" + return RunManager(store=MemoryRunStore()) + + +@pytest.mark.anyio +async def test_list_by_thread_returns_store_records_after_restart(manager_with_store: RunManager): + """After in-memory state is cleared (simulating restart), list_by_thread + should still return runs from the persistent store.""" + mgr = manager_with_store + r1 = await mgr.create("thread-1", "agent-1") + await mgr.set_status(r1.run_id, RunStatus.success) + r2 = await mgr.create("thread-1", "agent-2") + await mgr.set_status(r2.run_id, RunStatus.error, error="boom") + + # Clear in-memory dict to simulate a restart + mgr._runs.clear() + + runs = await mgr.list_by_thread("thread-1") + assert len(runs) == 2 + statuses = {r.run_id: r.status for r in runs} + assert statuses[r1.run_id] == RunStatus.success + assert statuses[r2.run_id] == RunStatus.error + # Verify other fields survive the round-trip + for r in runs: + assert r.thread_id == "thread-1" + assert ISO_RE.match(r.created_at) + + +@pytest.mark.anyio +async def test_list_by_thread_merges_in_memory_and_store(manager_with_store: RunManager): + """In-memory runs should be included alongside store-only records.""" + mgr = manager_with_store + + # Create a run and let it complete (will be in both memory and store) + r1 = await mgr.create("thread-1") + await mgr.set_status(r1.run_id, RunStatus.success) + + # Simulate restart: clear memory, then create a new in-memory run + mgr._runs.clear() + r2 = await mgr.create("thread-1") + + runs = await mgr.list_by_thread("thread-1") + assert len(runs) == 2 + run_ids = {r.run_id for r in runs} + assert r1.run_id in run_ids + assert r2.run_id in run_ids + + # r2 should be the in-memory record (has live state) + r2_record = next(r for r in runs if r.run_id == r2.run_id) + assert r2_record is r2 # same object reference + + +@pytest.mark.anyio +async def test_list_by_thread_no_store(): + """Without a store, list_by_thread should only return in-memory runs.""" + mgr = RunManager() + await mgr.create("thread-1") + + mgr._runs.clear() + runs = await mgr.list_by_thread("thread-1") + assert runs == [] + + +@pytest.mark.anyio +async def test_aget_returns_in_memory_record(manager_with_store: RunManager): + """aget should return the in-memory record when available.""" + mgr = manager_with_store + r1 = await mgr.create("thread-1", "agent-1") + + result = await mgr.aget(r1.run_id) + assert result is r1 # same object + + +@pytest.mark.anyio +async def test_aget_falls_back_to_store(manager_with_store: RunManager): + """aget should return a record from the store when not in memory.""" + mgr = manager_with_store + r1 = await mgr.create("thread-1", "agent-1") + await mgr.set_status(r1.run_id, RunStatus.success) + + mgr._runs.clear() + + result = await mgr.aget(r1.run_id) + assert result is not None + assert result.run_id == r1.run_id + assert result.status == RunStatus.success + assert result.thread_id == "thread-1" + assert result.assistant_id == "agent-1" + + +@pytest.mark.anyio +async def test_aget_falls_back_to_store_with_user_filter(): + """aget should honor user_id when reading store-only records.""" + store = MemoryRunStore() + await store.put("run-1", thread_id="thread-1", user_id="user-1", status="success") + mgr = RunManager(store=store) + + allowed = await mgr.aget("run-1", user_id="user-1") + denied = await mgr.aget("run-1", user_id="user-2") + assert allowed is not None + assert denied is None + + +@pytest.mark.anyio +async def test_aget_returns_none_for_unknown(manager_with_store: RunManager): + """aget should return None for a run ID that doesn't exist anywhere.""" + result = await manager_with_store.aget("nonexistent-run-id") + assert result is None + + +@pytest.mark.anyio +async def test_aget_store_failure_is_graceful(): + """If the store raises, aget should return None instead of propagating.""" + from unittest.mock import AsyncMock + + store = MemoryRunStore() + store.get = AsyncMock(side_effect=RuntimeError("db down")) + mgr = RunManager(store=store) + + result = await mgr.aget("some-id") + assert result is None + + +@pytest.mark.anyio +async def test_list_by_thread_store_failure_is_graceful(): + """If the store raises, list_by_thread should return only in-memory runs.""" + from unittest.mock import AsyncMock + + store = MemoryRunStore() + store.list_by_thread = AsyncMock(side_effect=RuntimeError("db down")) + mgr = RunManager(store=store) + + r1 = await mgr.create("thread-1") + runs = await mgr.list_by_thread("thread-1") + assert len(runs) == 1 + assert runs[0].run_id == r1.run_id + + +@pytest.mark.anyio +async def test_list_by_thread_falls_back_to_store_with_user_filter(): + """list_by_thread should return only the requesting user's store records.""" + store = MemoryRunStore() + await store.put("run-1", thread_id="thread-1", user_id="user-1", status="success") + await store.put("run-2", thread_id="thread-1", user_id="user-2", status="success") + mgr = RunManager(store=store) + + runs = await mgr.list_by_thread("thread-1", user_id="user-1") + assert [r.run_id for r in runs] == ["run-1"] From b5108e35206d989abca05c83a7831e0903bc975c Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Mon, 18 May 2026 22:07:01 +0800 Subject: [PATCH 38/86] fix(auth): replace setup-status 429 rate limit with cached response (#2915) * fix(auth): replace setup-status 429 rate limit with cached response The /api/v1/auth/setup-status endpoint had a 60-second cooldown that returned HTTP 429 for all but the first request per IP. When the service restarted with multiple browser tabs open, all tabs hit this endpoint simultaneously from the same source IP, causing a storm of 429 errors that blocked the login flow. Replace the cooldown-with-429 model with a per-IP response cache that returns the previously computed result within the TTL. The database query (count_admin_users) still only runs once per IP per 60 seconds, preserving the original performance goal while eliminating spurious 429 errors on multi-tab reconnection. Fixes #2902 * fix(auth): address setup-status cache review issues Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/439a0e8c-8b64-41d4-a3cd-fe9a00eec534 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * test(auth): improve readability of setup-status concurrency assertion Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/439a0e8c-8b64-41d4-a3cd-fe9a00eec534 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> * fix the unit test error --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- backend/app/gateway/routers/auth.py | 84 +++++++++++++++++-------- backend/tests/test_initialize_admin.py | 85 ++++++++++++++++++++++---- 2 files changed, 133 insertions(+), 36 deletions(-) diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py index 6192456fb..e57182c26 100644 --- a/backend/app/gateway/routers/auth.py +++ b/backend/app/gateway/routers/auth.py @@ -1,5 +1,6 @@ """Authentication endpoints.""" +import asyncio import logging import os import time @@ -382,9 +383,15 @@ async def get_me(request: Request): return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup) -_SETUP_STATUS_COOLDOWN: dict[str, float] = {} -_SETUP_STATUS_COOLDOWN_SECONDS = 60 +# Per-IP cache: ip → (timestamp, result_dict). +# Returns the cached result within the TTL instead of 429, because +# the answer (whether an admin exists) rarely changes and returning +# 429 breaks multi-tab / post-restart reconnection storms. +_SETUP_STATUS_CACHE: dict[str, tuple[float, dict]] = {} +_SETUP_STATUS_CACHE_TTL_SECONDS = 60 _MAX_TRACKED_SETUP_STATUS_IPS = 10000 +_SETUP_STATUS_INFLIGHT: dict[str, asyncio.Task[dict]] = {} +_SETUP_STATUS_INFLIGHT_GUARD = asyncio.Lock() @router.get("/setup-status") @@ -392,29 +399,56 @@ async def setup_status(request: Request): """Check if an admin account exists. Returns needs_setup=True when no admin exists.""" client_ip = _get_client_ip(request) now = time.time() - last_check = _SETUP_STATUS_COOLDOWN.get(client_ip, 0) - elapsed = now - last_check - if elapsed < _SETUP_STATUS_COOLDOWN_SECONDS: - retry_after = max(1, int(_SETUP_STATUS_COOLDOWN_SECONDS - elapsed)) - raise HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - detail="Setup status check is rate limited", - headers={"Retry-After": str(retry_after)}, - ) - # Evict stale entries when dict grows too large to bound memory usage. - if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS: - cutoff = now - _SETUP_STATUS_COOLDOWN_SECONDS - stale = [k for k, t in _SETUP_STATUS_COOLDOWN.items() if t < cutoff] - for k in stale: - del _SETUP_STATUS_COOLDOWN[k] - # If still too large after evicting expired entries, remove oldest half. - if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS: - by_time = sorted(_SETUP_STATUS_COOLDOWN.items(), key=lambda kv: kv[1]) - for k, _ in by_time[: len(by_time) // 2]: - del _SETUP_STATUS_COOLDOWN[k] - _SETUP_STATUS_COOLDOWN[client_ip] = now - admin_count = await get_local_provider().count_admin_users() - return {"needs_setup": admin_count == 0} + + # Return cached result when within TTL — avoids 429 on multi-tab reconnection. + cached = _SETUP_STATUS_CACHE.get(client_ip) + if cached is not None: + cached_time, cached_result = cached + if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS: + return cached_result + + async with _SETUP_STATUS_INFLIGHT_GUARD: + # Recheck cache after waiting for the inflight guard. + now = time.time() + cached = _SETUP_STATUS_CACHE.get(client_ip) + if cached is not None: + cached_time, cached_result = cached + if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS: + return cached_result + + task = _SETUP_STATUS_INFLIGHT.get(client_ip) + if task is None: + # Evict stale entries when dict grows too large to bound memory usage. + if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS: + cutoff = now - _SETUP_STATUS_CACHE_TTL_SECONDS + stale = [k for k, (t, _) in _SETUP_STATUS_CACHE.items() if t < cutoff] + for k in stale: + del _SETUP_STATUS_CACHE[k] + if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS: + by_time = sorted(_SETUP_STATUS_CACHE.items(), key=lambda entry: entry[1][0]) + for k, _ in by_time[: len(by_time) // 2]: + del _SETUP_STATUS_CACHE[k] + + async def _compute_setup_status() -> dict: + admin_count = await get_local_provider().count_admin_users() + return {"needs_setup": admin_count == 0} + + task = asyncio.create_task(_compute_setup_status()) + _SETUP_STATUS_INFLIGHT[client_ip] = task + + try: + result = await task + finally: + async with _SETUP_STATUS_INFLIGHT_GUARD: + if _SETUP_STATUS_INFLIGHT.get(client_ip) is task: + del _SETUP_STATUS_INFLIGHT[client_ip] + + # Cache only the stable "initialized" result to avoid stale setup redirects. + if result["needs_setup"] is False: + _SETUP_STATUS_CACHE[client_ip] = (time.time(), result) + else: + _SETUP_STATUS_CACHE.pop(client_ip, None) + return result class InitializeAdminRequest(BaseModel): diff --git a/backend/tests/test_initialize_admin.py b/backend/tests/test_initialize_admin.py index 26b2ec6b2..514ee6df3 100644 --- a/backend/tests/test_initialize_admin.py +++ b/backend/tests/test_initialize_admin.py @@ -22,7 +22,7 @@ _TEST_SECRET = "test-secret-key-initialize-admin-min-32" def _setup_auth(tmp_path): """Fresh SQLite engine + auth config per test.""" from app.gateway import deps - from app.gateway.routers.auth import _SETUP_STATUS_COOLDOWN + from app.gateway.routers.auth import _SETUP_STATUS_CACHE, _SETUP_STATUS_INFLIGHT from deerflow.persistence.engine import close_engine, init_engine set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) @@ -30,13 +30,15 @@ def _setup_auth(tmp_path): asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))) deps._cached_local_provider = None deps._cached_repo = None - _SETUP_STATUS_COOLDOWN.clear() + _SETUP_STATUS_CACHE.clear() + _SETUP_STATUS_INFLIGHT.clear() try: yield finally: deps._cached_local_provider = None deps._cached_repo = None - _SETUP_STATUS_COOLDOWN.clear() + _SETUP_STATUS_CACHE.clear() + _SETUP_STATUS_INFLIGHT.clear() asyncio.run(close_engine()) @@ -168,15 +170,76 @@ def test_setup_status_false_when_only_regular_user_exists(client): assert resp.json()["needs_setup"] is True -def test_setup_status_rate_limited_on_second_call(client): - """Second /setup-status call within the cooldown window returns 429 with Retry-After.""" - # First call succeeds. +def test_setup_status_returns_cached_result_on_rapid_calls(client): + """Rapid /setup-status calls return the cached result (200) instead of 429.""" + client.post("/api/v1/auth/initialize", json=_init_payload()) + + # First call succeeds and computes the result. resp1 = client.get("/api/v1/auth/setup-status") assert resp1.status_code == 200 - # Immediate second call is rate-limited. + # Immediate second call returns cached result, not 429. resp2 = client.get("/api/v1/auth/setup-status") - assert resp2.status_code == 429 - assert "Retry-After" in resp2.headers - retry_after = int(resp2.headers["Retry-After"]) - assert 1 <= retry_after <= 60 + assert resp2.status_code == 200 + assert resp2.json() == resp1.json() + assert resp2.json()["needs_setup"] is False + + +def test_setup_status_does_not_return_stale_true_after_initialize(client): + """A pre-initialize setup-status response should not stay cached as True.""" + before = client.get("/api/v1/auth/setup-status") + assert before.status_code == 200 + assert before.json()["needs_setup"] is True + + init = client.post("/api/v1/auth/initialize", json=_init_payload()) + assert init.status_code == 201 + + after = client.get("/api/v1/auth/setup-status") + assert after.status_code == 200 + assert after.json()["needs_setup"] is False + + +@pytest.mark.asyncio +async def test_setup_status_single_flight_per_ip(monkeypatch): + """Concurrent requests from same IP share one in-flight DB query.""" + from starlette.requests import Request + + from app.gateway.routers.auth import ( + _SETUP_STATUS_CACHE, + _SETUP_STATUS_INFLIGHT, + setup_status, + ) + + class _Provider: + def __init__(self): + self.calls = 0 + + async def count_admin_users(self): + self.calls += 1 + await asyncio.sleep(0.05) + return 0 + + provider = _Provider() + monkeypatch.setattr("app.gateway.routers.auth.get_local_provider", lambda: provider) + _SETUP_STATUS_CACHE.clear() + _SETUP_STATUS_INFLIGHT.clear() + + def _request() -> Request: + return Request( + { + "type": "http", + "method": "GET", + "path": "/api/v1/auth/setup-status", + "headers": [], + "client": ("127.0.0.1", 12345), + } + ) + + results = await asyncio.gather( + setup_status(_request()), + setup_status(_request()), + setup_status(_request()), + ) + + assert all(result["needs_setup"] is True for result in results) + assert provider.calls == 1 From 3acca1261475e2831b22da23c6985ddf68a00598 Mon Sep 17 00:00:00 2001 From: KiteEater <145987840+Kiteeater@users.noreply.github.com> Date: Mon, 18 May 2026 22:19:32 +0800 Subject: [PATCH 39/86] fix(subagents): make subagent timeout terminal state atomic (#2583) * Guard subagent terminal state transitions * fix: publish subagent terminal status last * Fix subagent timeout test to avoid blocking event loop * Fix subagent timeout test tracking * Refine subagent terminal state handling --------- Co-authored-by: Willem Jiang --- .../harness/deerflow/subagents/executor.py | 143 +++++++++++------- backend/tests/test_subagent_executor.py | 134 ++++++++++++++++ 2 files changed, 222 insertions(+), 55 deletions(-) diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index d6d2e4fc5..8fcbd5e1d 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -47,6 +47,15 @@ class SubagentStatus(Enum): CANCELLED = "cancelled" TIMED_OUT = "timed_out" + @property + def is_terminal(self) -> bool: + return self in { + type(self).COMPLETED, + type(self).FAILED, + type(self).CANCELLED, + type(self).TIMED_OUT, + } + @dataclass class SubagentResult: @@ -74,12 +83,48 @@ class SubagentResult: token_usage_records: list[dict[str, int | str]] = field(default_factory=list) usage_reported: bool = False cancel_event: threading.Event = field(default_factory=threading.Event, repr=False) + _state_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) def __post_init__(self): """Initialize mutable defaults.""" if self.ai_messages is None: self.ai_messages = [] + def try_set_terminal( + self, + status: SubagentStatus, + *, + result: str | None = None, + error: str | None = None, + completed_at: datetime | None = None, + ai_messages: list[dict[str, Any]] | None = None, + token_usage_records: list[dict[str, int | str]] | None = None, + ) -> bool: + """Set a terminal status exactly once. + + Background timeout/cancellation and the execution worker can race on the + same result holder. The first terminal transition wins; late terminal + writes must not change status or payload fields. + """ + if not status.is_terminal: + raise ValueError(f"Status {status} is not terminal") + + with self._state_lock: + if self.status.is_terminal: + return False + + if result is not None: + self.result = result + if error is not None: + self.error = error + if ai_messages is not None: + self.ai_messages = ai_messages + if token_usage_records is not None: + self.token_usage_records = token_usage_records + self.completed_at = completed_at or datetime.now() + self.status = status + return True + # Global storage for background task results _background_tasks: dict[str, SubagentResult] = {} @@ -459,13 +504,11 @@ class SubagentExecutor: # Pre-check: bail out immediately if already cancelled before streaming starts if result.cancel_event.is_set(): logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled before streaming") - with _background_tasks_lock: - if result.status == SubagentStatus.RUNNING: - result.status = SubagentStatus.CANCELLED - result.error = "Cancelled by user" - result.completed_at = datetime.now() - if collector is not None: - result.token_usage_records = collector.snapshot_records() + result.try_set_terminal( + SubagentStatus.CANCELLED, + error="Cancelled by user", + token_usage_records=collector.snapshot_records(), + ) return result async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type] @@ -475,12 +518,11 @@ class SubagentExecutor: # interrupted until the next chunk is yielded. if result.cancel_event.is_set(): logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} cancelled by parent") - with _background_tasks_lock: - if result.status == SubagentStatus.RUNNING: - result.status = SubagentStatus.CANCELLED - result.error = "Cancelled by user" - result.completed_at = datetime.now() - result.token_usage_records = collector.snapshot_records() + result.try_set_terminal( + SubagentStatus.CANCELLED, + error="Cancelled by user", + token_usage_records=collector.snapshot_records(), + ) return result final_state = chunk @@ -507,11 +549,12 @@ class SubagentExecutor: logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution") - result.token_usage_records = collector.snapshot_records() + token_usage_records = collector.snapshot_records() + final_result: str | None = None if final_state is None: logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state") - result.result = "No response generated" + final_result = "No response generated" else: # Extract the final message - find the last AIMessage messages = final_state.get("messages", []) @@ -528,7 +571,7 @@ class SubagentExecutor: content = last_ai_message.content # Handle both str and list content types for the final result if isinstance(content, str): - result.result = content + final_result = content elif isinstance(content, list): # Extract text from list of content blocks for final result only. # Concatenate raw string chunks directly, but preserve separation @@ -547,16 +590,16 @@ class SubagentExecutor: text_parts.append(text_val) if pending_str_parts: text_parts.append("".join(pending_str_parts)) - result.result = "\n".join(text_parts) if text_parts else "No text content in response" + final_result = "\n".join(text_parts) if text_parts else "No text content in response" else: - result.result = str(content) + final_result = str(content) elif messages: # Fallback: use the last message if no AIMessage found last_message = messages[-1] logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no AIMessage found, using last message: {type(last_message)}") raw_content = last_message.content if hasattr(last_message, "content") else str(last_message) if isinstance(raw_content, str): - result.result = raw_content + final_result = raw_content elif isinstance(raw_content, list): parts = [] pending_str_parts = [] @@ -572,23 +615,29 @@ class SubagentExecutor: parts.append(text_val) if pending_str_parts: parts.append("".join(pending_str_parts)) - result.result = "\n".join(parts) if parts else "No text content in response" + final_result = "\n".join(parts) if parts else "No text content in response" else: - result.result = str(raw_content) + final_result = str(raw_content) else: logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no messages in final state") - result.result = "No response generated" + final_result = "No response generated" - result.status = SubagentStatus.COMPLETED - result.completed_at = datetime.now() + if final_result is None: + final_result = "No response generated" + + result.try_set_terminal( + SubagentStatus.COMPLETED, + result=final_result, + token_usage_records=token_usage_records, + ) except Exception as e: logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed") - result.status = SubagentStatus.FAILED - result.error = str(e) - result.completed_at = datetime.now() - if collector is not None: - result.token_usage_records = collector.snapshot_records() + result.try_set_terminal( + SubagentStatus.FAILED, + error=str(e), + token_usage_records=collector.snapshot_records() if collector is not None else None, + ) return result @@ -667,11 +716,9 @@ class SubagentExecutor: result = SubagentResult( task_id=str(uuid.uuid4())[:8], trace_id=self.trace_id, - status=SubagentStatus.FAILED, + status=SubagentStatus.RUNNING, ) - result.status = SubagentStatus.FAILED - result.error = str(e) - result.completed_at = datetime.now() + result.try_set_terminal(SubagentStatus.FAILED, error=str(e)) return result def execute_async(self, task: str, task_id: str | None = None) -> str: @@ -718,29 +765,21 @@ class SubagentExecutor: ) try: # Wait for execution with timeout - exec_result = execution_future.result(timeout=self.config.timeout_seconds) - with _background_tasks_lock: - _background_tasks[task_id].status = exec_result.status - _background_tasks[task_id].result = exec_result.result - _background_tasks[task_id].error = exec_result.error - _background_tasks[task_id].completed_at = datetime.now() - _background_tasks[task_id].ai_messages = exec_result.ai_messages + execution_future.result(timeout=self.config.timeout_seconds) except FuturesTimeoutError: logger.error(f"[trace={self.trace_id}] Subagent {self.config.name} execution timed out after {self.config.timeout_seconds}s") - with _background_tasks_lock: - if _background_tasks[task_id].status == SubagentStatus.RUNNING: - _background_tasks[task_id].status = SubagentStatus.TIMED_OUT - _background_tasks[task_id].error = f"Execution timed out after {self.config.timeout_seconds} seconds" - _background_tasks[task_id].completed_at = datetime.now() # Signal cooperative cancellation and cancel the future result_holder.cancel_event.set() + result_holder.try_set_terminal( + SubagentStatus.TIMED_OUT, + error=f"Execution timed out after {self.config.timeout_seconds} seconds", + ) execution_future.cancel() except Exception as e: logger.exception(f"[trace={self.trace_id}] Subagent {self.config.name} async execution failed") with _background_tasks_lock: - _background_tasks[task_id].status = SubagentStatus.FAILED - _background_tasks[task_id].error = str(e) - _background_tasks[task_id].completed_at = datetime.now() + task_result = _background_tasks[task_id] + task_result.try_set_terminal(SubagentStatus.FAILED, error=str(e)) _scheduler_pool.submit(run_task) return task_id @@ -811,13 +850,7 @@ def cleanup_background_task(task_id: str) -> None: # Only clean up tasks that are in a terminal state to avoid races with # the background executor still updating the task entry. - is_terminal_status = result.status in { - SubagentStatus.COMPLETED, - SubagentStatus.FAILED, - SubagentStatus.CANCELLED, - SubagentStatus.TIMED_OUT, - } - if is_terminal_status or result.completed_at is not None: + if result.status.is_terminal or result.completed_at is not None: del _background_tasks[task_id] logger.debug("Cleaned up background task: %s", task_id) else: diff --git a/backend/tests/test_subagent_executor.py b/backend/tests/test_subagent_executor.py index 87c82ff96..8987958a8 100644 --- a/backend/tests/test_subagent_executor.py +++ b/backend/tests/test_subagent_executor.py @@ -1125,6 +1125,15 @@ class TestAsyncToolSupport: class TestThreadSafety: """Test thread safety of executor operations.""" + @pytest.fixture + def executor_module(self, _setup_executor_classes): + """Import the executor module with real classes.""" + import importlib + + from deerflow.subagents import executor + + return importlib.reload(executor) + def test_multiple_executors_in_parallel(self, classes, base_config, msg): """Test multiple executors running in parallel via thread pool.""" from concurrent.futures import ThreadPoolExecutor, as_completed @@ -1170,6 +1179,68 @@ class TestThreadSafety: assert result.status == SubagentStatus.COMPLETED assert "Result" in result.result + def test_terminal_status_is_published_after_payload_fields(self, executor_module, monkeypatch): + """Readers must not observe terminal status before terminal payload is complete.""" + SubagentResult = executor_module.SubagentResult + SubagentStatus = executor_module.SubagentStatus + + now_entered = threading.Event() + release_now = threading.Event() + completed_at = datetime(2026, 5, 1, 12, 0, 0) + writer_errors: list[BaseException] = [] + + class BlockingDateTime: + @staticmethod + def now(): + now_entered.set() + release_now.wait(timeout=5) + return completed_at + + monkeypatch.setattr(executor_module, "datetime", BlockingDateTime) + + result = SubagentResult( + task_id="test-terminal-publication-order", + trace_id="test-trace", + status=SubagentStatus.RUNNING, + ) + token_usage_records = [ + { + "source_run_id": "run-1", + "caller": "subagent:test-agent", + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + } + ] + + def set_terminal(): + try: + assert result.try_set_terminal( + SubagentStatus.COMPLETED, + result="done", + token_usage_records=token_usage_records, + ) + except BaseException as exc: + writer_errors.append(exc) + + writer = threading.Thread(target=set_terminal) + writer.start() + + assert now_entered.wait(timeout=3), "try_set_terminal did not reach completed_at assignment" + assert result.completed_at is None + assert result.status == SubagentStatus.RUNNING + assert result.token_usage_records == token_usage_records + + release_now.set() + writer.join(timeout=3) + + assert not writer.is_alive(), "try_set_terminal did not finish" + assert writer_errors == [] + assert result.completed_at == completed_at + assert result.status == SubagentStatus.COMPLETED + assert result.result == "done" + assert result.token_usage_records == token_usage_records + # ----------------------------------------------------------------------------- # Cleanup Background Task Tests @@ -1604,6 +1675,69 @@ class TestCooperativeCancellation: assert result.error == "Cancelled by user" assert result.completed_at is not None + def test_late_completion_after_timeout_does_not_overwrite_timed_out(self, executor_module, classes, msg): + """Late completion from the execution worker must not overwrite TIMED_OUT.""" + SubagentExecutor = classes["SubagentExecutor"] + SubagentStatus = classes["SubagentStatus"] + + short_config = classes["SubagentConfig"]( + name="test-agent", + description="Test agent", + system_prompt="You are a test agent.", + max_turns=10, + timeout_seconds=0.05, + ) + + first_chunk_seen = threading.Event() + finish_stream = threading.Event() + execution_done = threading.Event() + + async def mock_astream(*args, **kwargs): + yield {"messages": [msg.human("Task"), msg.ai("late completion", "msg-late")]} + first_chunk_seen.set() + deadline = asyncio.get_running_loop().time() + 5 + while not finish_stream.is_set(): + if asyncio.get_running_loop().time() >= deadline: + break + await asyncio.sleep(0.001) + + mock_agent = MagicMock() + mock_agent.astream = mock_astream + + executor = SubagentExecutor( + config=short_config, + tools=[], + thread_id="test-thread", + trace_id="test-trace", + ) + original_aexecute = executor._aexecute + + async def tracked_aexecute(task, result_holder=None): + try: + return await original_aexecute(task, result_holder) + finally: + execution_done.set() + + with patch.object(executor, "_create_agent", return_value=mock_agent), patch.object(executor, "_aexecute", tracked_aexecute): + task_id = executor.execute_async("Task") + assert first_chunk_seen.wait(timeout=3), "stream did not yield initial chunk" + + result = executor_module._background_tasks[task_id] + assert result.cancel_event.wait(timeout=3), "timeout handler did not request cancellation" + assert result.status.value == SubagentStatus.TIMED_OUT.value + timed_out_error = result.error + timed_out_completed_at = result.completed_at + + finish_stream.set() + assert execution_done.wait(timeout=3), "execution worker did not finish" + + result = executor_module._background_tasks.get(task_id) + assert result is not None + assert result.status.value == SubagentStatus.TIMED_OUT.value + assert result.result is None + assert result.error == timed_out_error + assert result.completed_at == timed_out_completed_at + def test_cleanup_removes_cancelled_task(self, executor_module, classes): """Test that cleanup removes a CANCELLED task (terminal state).""" SubagentResult = classes["SubagentResult"] From c810e9f809222ad8449348667b5c438803c7885e Mon Sep 17 00:00:00 2001 From: He Wang Date: Mon, 18 May 2026 22:25:02 +0800 Subject: [PATCH 40/86] fix(harness)!: hydrate runs from RunStore and persist interrupted status (#2932) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(harness): hydrate run history from RunStore and persist cancellation status fix: - Make RunManager.get() async and hydrate from RunStore when in-memory record is missing - Merge store rows into list_by_thread() with in-memory precedence for active runs - Persist interrupted status to RunStore in cancel() and create_or_reject(interrupt|rollback) - Extract _persist_status() to reuse the best-effort store update pattern - Await run_mgr.get() in all gateway endpoints - Return 409 with distinct message for store-only runs not active on current worker Closes #2812, Closes #2813 Co-Authored-By: Claude Opus 4.7 * fix(harness): consistent sort and guarded hydration in RunManager fix: - list_by_thread() now sorts by created_at desc (newest first) even when no RunStore is configured, matching the store-backed code path - guard _record_from_store() call sites in get() and list_by_thread() with best-effort error handling so a single malformed store row cannot turn read paths into 500s test: - update test_list_by_thread assertion to expect newest-first order - seed MemoryRunStore via public put() API instead of writing to _runs * fix(harness): guard store-only runs from streaming and fix get() TOCTOU Add RunRecord.store_only flag set by _record_from_store so callers can distinguish hydrated history from live in-memory runs. join_run and stream_existing_run (action=None) now return 409 instead of hanging forever on an empty MemoryStreamBridge channel. Re-check _runs under lock after the store await in RunManager.get() so a concurrent create() that lands between the two checks returns the authoritative in-memory record rather than a stale store-hydrated copy. Co-Authored-By: Claude Sonnet 4 * fix(harness): reorder bridge fetch in join_run and make list_by_thread limit explicit Move get_stream_bridge() after the store_only guard in join_run so a missing bridge cannot produce 503 for historical runs before the 409 guard fires. Add limit parameter to RunManager.list_by_thread (default 100, matching the store's page size) and pass it explicitly to the store call. Update docstring to document the limit instead of claiming all runs are returned. Co-Authored-By: Claude Sonnet 4 * fix(harness): cap list_by_thread result to limit after merge Apply [:limit] to all return paths in list_by_thread so the method consistently returns at most limit records regardless of how many in-memory runs exist, making the limit parameter a true upper bound on the response size rather than just a store-query hint. Co-Authored-By: Claude Sonnet 4 * fix `list_by_thread` docstring Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix(runtime): add update_model_name to RunStore to prevent SQL integrity errors RunManager.update_model_name() was calling _persist_to_store() which uses RunStore.put(), but RunRepository.put() is insert-only. This caused integrity errors when updating model_name for existing runs in SQL-backed stores. fix: - Add abstract update_model_name method to RunStore base class - Implement update_model_name in MemoryRunStore - Implement update_model_name in RunRepository with proper normalization - Add _persist_model_name helper in RunManager - Update RunManager.update_model_name to use the new method test: - Add tests for update_model_name functionality - Add integration tests for RunManager with SQL-backed store Co-Authored-By: Claude Opus 4.7 * fix(runtime): handle NULL status/on_disconnect in _record_from_store `dict.get(key, default)` only uses the default when the key is absent, so a SQL row with an explicit NULL status would pass `None` to `RunStatus(None)` and raise, breaking hydration for otherwise valid rows. Switch to `row.get(...) or fallback` so both missing and NULL values get a safe default. Add tests for get() and list_by_thread() with a NULL status row to prevent regression. Co-Authored-By: Claude Sonnet 4 * fix(runs): address PR review feedback on store consistency changes - Fix list_by_thread limit semantics: pass store_limit = max(0, limit - len(memory_records)) to store so newer store records are not crowded out by in-memory records - Remove dead code: cancelled guard after raise is always True, simplify to if wait and record.task - Document _record_from_store NULL fallback policy (status→pending, on_disconnect→cancel) in docstring Co-Authored-By: Claude Sonnet 4 --------- Co-authored-by: Claude Opus 4.7 Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- backend/CLAUDE.md | 6 + backend/app/gateway/routers/thread_runs.py | 31 ++-- .../harness/deerflow/persistence/run/sql.py | 5 + .../harness/deerflow/runtime/runs/manager.py | 171 +++++++++++------ .../deerflow/runtime/runs/store/base.py | 9 + .../deerflow/runtime/runs/store/memory.py | 5 + backend/tests/test_run_manager.py | 175 +++++++++++++++++- backend/tests/test_run_repository.py | 103 +++++++++++ backend/tests/test_run_worker_rollback.py | 4 +- .../test_thread_run_messages_pagination.py | 67 ++++++- 10 files changed, 499 insertions(+), 77 deletions(-) diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 35607c6fd..b951f919c 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -225,6 +225,12 @@ CORS is same-origin by default when requests enter through nginx on port 2026. S | **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific | | **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id | +**RunManager / RunStore contract**: +- `RunManager.get()` is async; direct callers must `await` it. +- When a persistent `RunStore` is configured, `get()` and `list_by_thread()` hydrate historical runs from the store. In-memory records win for the same `run_id` so task, abort, and stream-control state stays attached to active local runs. +- `cancel()` and `create_or_reject(..., multitask_strategy="interrupt"|"rollback")` persist interrupted status through `RunStore.update_status()`, matching normal `set_status()` transitions. +- Store-only hydrated runs are readable history. If the current worker has no in-memory task/control state for that run, cancellation APIs can return 409 because this worker cannot stop the task. + Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runtime, all other `/api/*` → Gateway REST APIs. ### Sandbox System (`packages/harness/deerflow/sandbox/`) diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index 3d429fc03..294fa9799 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -22,7 +22,7 @@ from pydantic import BaseModel, Field from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge from app.gateway.services import sse_consumer, start_run -from deerflow.runtime import RunRecord, serialize_channel_values +from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/threads", tags=["runs"]) @@ -94,6 +94,12 @@ class ThreadTokenUsageResponse(BaseModel): # --------------------------------------------------------------------------- +def _cancel_conflict_detail(run_id: str, record: RunRecord) -> str: + if record.status in (RunStatus.pending, RunStatus.running): + return f"Run {run_id} is not active on this worker and cannot be cancelled" + return f"Run {run_id} is not cancellable (status: {record.status.value})" + + def _record_to_response(record: RunRecord) -> RunResponse: return RunResponse( run_id=record.run_id, @@ -191,7 +197,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: """Get details of a specific run.""" run_mgr = get_run_manager(request) user_id = await get_current_user(request) - record = await run_mgr.aget(run_id, user_id=user_id) + record = await run_mgr.get(run_id, user_id=user_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") return _record_to_response(record) @@ -214,16 +220,13 @@ async def cancel_run( - wait=false: Return immediately with 202 """ run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) + record = await run_mgr.get(run_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") cancelled = await run_mgr.cancel(run_id, action=action) if not cancelled: - raise HTTPException( - status_code=409, - detail=f"Run {run_id} is not cancellable (status: {record.status.value})", - ) + raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record)) if wait and record.task is not None: try: @@ -239,12 +242,14 @@ async def cancel_run( @require_permission("runs", "read", owner_check=True) async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse: """Join an existing run's SSE stream.""" - bridge = get_stream_bridge(request) run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) + record = await run_mgr.get(run_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + if record.store_only: + raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed") + bridge = get_stream_bridge(request) return StreamingResponse( sse_consumer(bridge, record, request, run_mgr), media_type="text/event-stream", @@ -273,14 +278,18 @@ async def stream_existing_run( remaining buffered events so the client observes a clean shutdown. """ run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) + record = await run_mgr.get(run_id) if record is None or record.thread_id != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + if record.store_only and action is None: + raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed") # Cancel if an action was requested (stop-button / interrupt flow) if action is not None: cancelled = await run_mgr.cancel(run_id, action=action) - if cancelled and wait and record.task is not None: + if not cancelled: + raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record)) + if wait and record.task is not None: try: await record.task except (asyncio.CancelledError, Exception): diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 5331451e3..d586a2b13 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -151,6 +151,11 @@ class RunRepository(RunStore): await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.commit() + async def update_model_name(self, run_id, model_name): + async with self._sf() as session: + await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(model_name=self._normalize_model_name(model_name), updated_at=datetime.now(UTC))) + await session.commit() + async def delete( self, run_id, diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 11d6b478e..06731eb91 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -6,7 +6,7 @@ import asyncio import logging import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from deerflow.utils.time import now_iso as _now_iso @@ -37,6 +37,7 @@ class RunRecord: abort_action: str = "interrupt" error: str | None = None model_name: str | None = None + store_only: bool = False class RunManager: @@ -71,6 +72,38 @@ class RunManager: except Exception: logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) + async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: + """Best-effort persist a status transition to the backing store.""" + if self._store is None: + return + try: + await self._store.update_status(run_id, status.value, error=error) + except Exception: + logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) + + @staticmethod + def _record_from_store(row: dict[str, Any]) -> RunRecord: + """Build a read-only runtime record from a serialized store row. + + NULL status/on_disconnect columns (e.g. from rows written before those + columns were added) default to ``pending`` and ``cancel`` respectively. + """ + return RunRecord( + run_id=row["run_id"], + thread_id=row["thread_id"], + assistant_id=row.get("assistant_id"), + status=RunStatus(row.get("status") or RunStatus.pending.value), + on_disconnect=DisconnectMode(row.get("on_disconnect") or DisconnectMode.cancel.value), + multitask_strategy=row.get("multitask_strategy") or "reject", + metadata=row.get("metadata") or {}, + kwargs=row.get("kwargs") or {}, + created_at=row.get("created_at") or "", + updated_at=row.get("updated_at") or "", + error=row.get("error"), + model_name=row.get("model_name"), + store_only=True, + ) + async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" if self._store is not None: @@ -110,61 +143,77 @@ class RunManager: logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record - def get(self, run_id: str) -> RunRecord | None: - """Return an in-memory run record by ID, or ``None``.""" - return self._runs.get(run_id) + async def get(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: + """Return a run record by ID, or ``None``. - async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: - """Return a run record by ID, checking the persistent store as fallback.""" - record = self._runs.get(run_id) + Args: + run_id: The run ID to look up. + user_id: Optional user ID for permission filtering when hydrating from store. + """ + async with self._lock: + record = self._runs.get(run_id) if record is not None: return record - if self._store is not None: - try: - d = await self._store.get(run_id, user_id=user_id) - if d is not None: - return self._store_dict_to_record(d) - except Exception: - logger.warning("Failed to query store for run %s", run_id, exc_info=True) - return None - - def _store_dict_to_record(self, d: dict) -> RunRecord: - """Convert a store dict back to a RunRecord for read-only use.""" - return RunRecord( - run_id=d["run_id"], - thread_id=d["thread_id"], - assistant_id=d.get("assistant_id"), - status=RunStatus(d.get("status", RunStatus.error.value)), - on_disconnect=DisconnectMode.cancel, - multitask_strategy=d.get("multitask_strategy", "reject"), - metadata=d.get("metadata", {}), - kwargs=d.get("kwargs", {}), - created_at=d.get("created_at", ""), - updated_at=d.get("updated_at", ""), - model_name=d.get("model_name"), - error=d.get("error"), - ) - - async def list_by_thread(self, thread_id: str, *, user_id: str | None = None) -> list[RunRecord]: - """Return all runs for a given thread, oldest first.""" + if self._store is None: + return None + try: + row = await self._store.get(run_id, user_id=user_id) + except Exception: + logger.warning("Failed to hydrate run %s from store", run_id, exc_info=True) + return None + # Re-check after store await: a concurrent create() may have inserted the + # in-memory record while the store call was in flight. async with self._lock: - in_memory = [r for r in self._runs.values() if r.thread_id == thread_id] - in_memory_ids = {r.run_id for r in in_memory} + record = self._runs.get(run_id) + if record is not None: + return record + if row is None: + return None + try: + return self._record_from_store(row) + except Exception: + logger.warning("Failed to map store row for run %s", run_id, exc_info=True) + return None - store_records: list[RunRecord] = [] - if self._store is not None: - try: - store_dicts = await self._store.list_by_thread(thread_id, user_id=user_id) - for d in store_dicts: - if d["run_id"] not in in_memory_ids: - store_records.append(self._store_dict_to_record(d)) - except Exception: - logger.warning("Failed to query store for thread %s runs", thread_id, exc_info=True) + async def aget(self, run_id: str, *, user_id: str | None = None) -> RunRecord | None: + """Return a run record by ID, checking the persistent store as fallback. - return sorted( - in_memory + store_records, - key=lambda record: record.created_at or "", - ) + Alias for :meth:`get` for backward compatibility. + """ + return await self.get(run_id, user_id=user_id) + + async def list_by_thread(self, thread_id: str, *, user_id: str | None = None, limit: int = 100) -> list[RunRecord]: + """Return runs for a given thread, newest first, at most ``limit`` records. + + In-memory runs take precedence only when the same ``run_id`` exists in both + memory and the backing store. The merged result is then sorted newest-first + by ``created_at`` and trimmed to ``limit`` (default 100). + + Args: + thread_id: The thread ID to filter by. + user_id: Optional user ID for permission filtering when hydrating from store. + limit: Maximum number of runs to return. + """ + async with self._lock: + # Dict insertion order gives deterministic results when timestamps tie. + memory_records = [r for r in self._runs.values() if r.thread_id == thread_id] + if self._store is None: + return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] + records_by_id = {record.run_id: record for record in memory_records} + store_limit = max(0, limit - len(memory_records)) + try: + rows = await self._store.list_by_thread(thread_id, user_id=user_id, limit=store_limit) + except Exception: + logger.warning("Failed to hydrate runs for thread %s from store", thread_id, exc_info=True) + return sorted(memory_records, key=lambda r: r.created_at, reverse=True)[:limit] + for row in rows: + run_id = row.get("run_id") + if run_id and run_id not in records_by_id: + try: + records_by_id[run_id] = self._record_from_store(row) + except Exception: + logger.warning("Failed to map store row for run %s", run_id, exc_info=True) + return sorted(records_by_id.values(), key=lambda record: record.created_at, reverse=True)[:limit] async def set_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: """Transition a run to a new status.""" @@ -177,13 +226,18 @@ class RunManager: record.updated_at = _now_iso() if error is not None: record.error = error - if self._store is not None: - try: - await self._store.update_status(run_id, status.value, error=error) - except Exception: - logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) + await self._persist_status(run_id, status, error=error) logger.info("Run %s -> %s", run_id, status.value) + async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: + """Best-effort persist model_name update to the backing store.""" + if self._store is None: + return + try: + await self._store.update_model_name(run_id, model_name) + except Exception: + logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) + async def update_model_name(self, run_id: str, model_name: str | None) -> None: """Update the model name for a run.""" async with self._lock: @@ -193,7 +247,7 @@ class RunManager: return record.model_name = model_name record.updated_at = _now_iso() - await self._persist_to_store(record) + await self._persist_model_name(run_id, model_name) logger.info("Run %s model_name=%s", run_id, model_name) async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool: @@ -218,6 +272,7 @@ class RunManager: record.task.cancel() record.status = RunStatus.interrupted record.updated_at = _now_iso() + await self._persist_status(run_id, RunStatus.interrupted) logger.info("Run %s cancelled (action=%s)", run_id, action) return True @@ -245,6 +300,7 @@ class RunManager: now = _now_iso() _supported_strategies = ("reject", "interrupt", "rollback") + interrupted_run_ids: list[str] = [] async with self._lock: if multitask_strategy not in _supported_strategies: @@ -263,6 +319,7 @@ class RunManager: r.task.cancel() r.status = RunStatus.interrupted r.updated_at = now + interrupted_run_ids.append(r.run_id) logger.info( "Cancelled %d inflight run(s) on thread %s (strategy=%s)", len(inflight), @@ -285,6 +342,8 @@ class RunManager: ) self._runs[run_id] = record + for interrupted_run_id in interrupted_run_ids: + await self._persist_status(interrupted_run_id, RunStatus.interrupted) await self._persist_to_store(record) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index a742d89ca..10c90d7ea 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -66,6 +66,15 @@ class RunStore(abc.ABC): async def delete(self, run_id: str) -> None: pass + @abc.abstractmethod + async def update_model_name( + self, + run_id: str, + model_name: str | None, + ) -> None: + """Update the model_name field for an existing run.""" + pass + @abc.abstractmethod async def update_run_completion( self, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 9db27cacc..56ef02b5b 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -66,6 +66,11 @@ class MemoryRunStore(RunStore): self._runs[run_id]["error"] = error self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def update_model_name(self, run_id, model_name): + if run_id in self._runs: + self._runs[run_id]["model_name"] = model_name + self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def delete(self, run_id): self._runs.pop(run_id, None) diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index de8f66319..e7b5f06f5 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -4,7 +4,7 @@ import re import pytest -from deerflow.runtime import RunManager, RunStatus +from deerflow.runtime import DisconnectMode, RunManager, RunStatus from deerflow.runtime.runs.store.memory import MemoryRunStore ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -34,7 +34,7 @@ async def test_create_and_get(manager: RunManager): assert ISO_RE.match(record.created_at) assert ISO_RE.match(record.updated_at) - fetched = manager.get(record.run_id) + fetched = await manager.get(record.run_id) assert fetched is record @@ -64,6 +64,22 @@ async def test_cancel(manager: RunManager): assert record.status == RunStatus.interrupted +@pytest.mark.anyio +async def test_cancel_persists_interrupted_status_to_store(): + """Cancel should persist interrupted status to the backing store.""" + store = MemoryRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.running) + + cancelled = await manager.cancel(record.run_id) + + stored = await store.get(record.run_id) + assert cancelled is True + assert stored is not None + assert stored["status"] == "interrupted" + + @pytest.mark.anyio async def test_cancel_not_inflight(manager: RunManager): """Cancelling a completed run should return False.""" @@ -83,9 +99,9 @@ async def test_list_by_thread(manager: RunManager): runs = await manager.list_by_thread("thread-1") assert len(runs) == 2 - # list_by_thread returns oldest-first (ascending created_at). - assert runs[0].run_id == r1.run_id - assert runs[1].run_id == r2.run_id + # Newest first: r2 was created after r1. + assert runs[0].run_id == r2.run_id + assert runs[1].run_id == r1.run_id @pytest.mark.anyio @@ -117,7 +133,7 @@ async def test_cleanup(manager: RunManager): run_id = record.run_id await manager.cleanup(run_id, delay=0) - assert manager.get(run_id) is None + assert await manager.get(run_id) is None @pytest.mark.anyio @@ -132,7 +148,116 @@ async def test_set_status_with_error(manager: RunManager): @pytest.mark.anyio async def test_get_nonexistent(manager: RunManager): """Getting a nonexistent run should return None.""" - assert manager.get("does-not-exist") is None + assert await manager.get("does-not-exist") is None + + +@pytest.mark.anyio +async def test_get_hydrates_store_only_run(): + """Store-only runs should be readable after process restart.""" + store = MemoryRunStore() + await store.put( + "run-store-only", + thread_id="thread-1", + assistant_id="lead_agent", + status="success", + multitask_strategy="reject", + metadata={"source": "store"}, + kwargs={"input": "value"}, + created_at="2026-01-01T00:00:00+00:00", + model_name="model-a", + ) + manager = RunManager(store=store) + + record = await manager.get("run-store-only") + + assert record is not None + assert record.run_id == "run-store-only" + assert record.thread_id == "thread-1" + assert record.assistant_id == "lead_agent" + assert record.status == RunStatus.success + assert record.on_disconnect == DisconnectMode.cancel + assert record.metadata == {"source": "store"} + assert record.kwargs == {"input": "value"} + assert record.model_name == "model-a" + assert record.task is None + assert record.store_only is True + + +@pytest.mark.anyio +async def test_get_hydrates_run_with_null_enum_fields(): + """Rows with NULL status/on_disconnect must hydrate with safe defaults, not raise.""" + store = MemoryRunStore() + # Simulate a SQL row where the nullable status column is NULL + await store.put( + "run-null-status", + thread_id="thread-1", + status=None, + created_at="2026-01-01T00:00:00+00:00", + ) + manager = RunManager(store=store) + + record = await manager.get("run-null-status") + + assert record is not None + assert record.status == RunStatus.pending + assert record.on_disconnect == DisconnectMode.cancel + assert record.store_only is True + + +@pytest.mark.anyio +async def test_list_by_thread_hydrates_run_with_null_enum_fields(): + """list_by_thread must not skip rows with NULL status; applies safe defaults.""" + store = MemoryRunStore() + await store.put( + "run-null-status-list", + thread_id="thread-null", + status=None, + created_at="2026-01-01T00:00:00+00:00", + ) + manager = RunManager(store=store) + + runs = await manager.list_by_thread("thread-null") + + assert len(runs) == 1 + assert runs[0].run_id == "run-null-status-list" + assert runs[0].status == RunStatus.pending + assert runs[0].on_disconnect == DisconnectMode.cancel + + +@pytest.mark.anyio +async def test_create_record_is_not_store_only(manager: RunManager): + """In-memory records created via create() must have store_only=False.""" + record = await manager.create("thread-1") + assert record.store_only is False + + +@pytest.mark.anyio +async def test_get_prefers_in_memory_record_over_store(): + """In-memory records retain task/control state when store has same run.""" + store = MemoryRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await store.update_status(record.run_id, "success") + + fetched = await manager.get(record.run_id) + + assert fetched is record + assert fetched.status == RunStatus.pending + + +@pytest.mark.anyio +async def test_list_by_thread_merges_store_runs_newest_first(): + """list_by_thread should merge memory and store rows with memory precedence.""" + store = MemoryRunStore() + await store.put("old-store", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:00+00:00") + await store.put("other-thread", thread_id="thread-2", status="success", created_at="2026-01-03T00:00:00+00:00") + manager = RunManager(store=store) + memory_record = await manager.create("thread-1") + + runs = await manager.list_by_thread("thread-1") + + assert [run.run_id for run in runs] == [memory_record.run_id, "old-store"] + assert runs[0] is memory_record @pytest.mark.anyio @@ -171,11 +296,45 @@ async def test_model_name_create_or_reject(): assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0" # Verify retrieval returns the model_name via in-memory record - fetched = mgr.get(record.run_id) + fetched = await mgr.get(record.run_id) assert fetched is not None assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" +@pytest.mark.anyio +async def test_create_or_reject_interrupt_persists_interrupted_status_to_store(): + """interrupt strategy should persist interrupted status for old runs.""" + store = MemoryRunStore() + manager = RunManager(store=store) + old = await manager.create("thread-1") + await manager.set_status(old.run_id, RunStatus.running) + + new = await manager.create_or_reject("thread-1", multitask_strategy="interrupt") + + stored_old = await store.get(old.run_id) + assert new.run_id != old.run_id + assert old.status == RunStatus.interrupted + assert stored_old is not None + assert stored_old["status"] == "interrupted" + + +@pytest.mark.anyio +async def test_create_or_reject_rollback_persists_interrupted_status_to_store(): + """rollback strategy should persist interrupted status for old runs.""" + store = MemoryRunStore() + manager = RunManager(store=store) + old = await manager.create("thread-1") + await manager.set_status(old.run_id, RunStatus.running) + + new = await manager.create_or_reject("thread-1", multitask_strategy="rollback") + + stored_old = await store.get(old.run_id) + assert new.run_id != old.run_id + assert old.status == RunStatus.interrupted + assert stored_old is not None + assert stored_old["status"] == "interrupted" + + @pytest.mark.anyio async def test_model_name_default_is_none(): """create_or_reject without model_name should default to None.""" diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 5e230e790..5809db517 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -9,6 +9,7 @@ import pytest from sqlalchemy.dialects import postgresql from deerflow.persistence.run import RunRepository +from deerflow.runtime import RunManager, RunStatus async def _make_repo(tmp_path): @@ -326,3 +327,105 @@ class TestRunRepository: assert select_match is not None assert group_by_match is not None assert select_match.group(1) == group_by_match.group(1) + + @pytest.mark.anyio + async def test_run_manager_hydrates_store_only_run_from_sql(self, tmp_path): + """RunManager should hydrate historical runs from SQL-backed store.""" + repo = await _make_repo(tmp_path) + await repo.put( + "sql-store-only", + thread_id="thread-1", + assistant_id="lead_agent", + status="success", + metadata={"source": "sql"}, + kwargs={"input": "value"}, + model_name="model-a", + ) + manager = RunManager(store=repo) + + record = await manager.get("sql-store-only") + rows = await manager.list_by_thread("thread-1") + + assert record is not None + assert record.run_id == "sql-store-only" + assert record.status == RunStatus.success + assert record.metadata == {"source": "sql"} + assert record.kwargs == {"input": "value"} + assert record.model_name == "model-a" + assert [run.run_id for run in rows] == ["sql-store-only"] + await _cleanup() + + @pytest.mark.anyio + async def test_run_manager_cancel_persists_interrupted_status_to_sql(self, tmp_path): + """RunManager.cancel should write interrupted status to SQL-backed store.""" + repo = await _make_repo(tmp_path) + manager = RunManager(store=repo) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.running) + + cancelled = await manager.cancel(record.run_id) + row = await repo.get(record.run_id) + + assert cancelled is True + assert row is not None + assert row["status"] == "interrupted" + await _cleanup() + + @pytest.mark.anyio + async def test_update_model_name(self, tmp_path): + """RunRepository.update_model_name should update model_name for existing run.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", model_name="initial-model") + await repo.update_model_name("r1", "updated-model") + row = await repo.get("r1") + assert row["model_name"] == "updated-model" + await _cleanup() + + @pytest.mark.anyio + async def test_update_model_name_normalizes_value(self, tmp_path): + """RunRepository.update_model_name should normalize and truncate model_name.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1") + long_name = "a" * 200 + await repo.update_model_name("r1", long_name) + row = await repo.get("r1") + assert row["model_name"] == "a" * 128 + await _cleanup() + + @pytest.mark.anyio + async def test_update_model_name_to_none(self, tmp_path): + """RunRepository.update_model_name should allow setting model_name to None.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", model_name="initial-model") + await repo.update_model_name("r1", None) + row = await repo.get("r1") + assert row["model_name"] is None + await _cleanup() + + @pytest.mark.anyio + async def test_run_manager_update_model_name_persists_to_sql(self, tmp_path): + """RunManager.update_model_name should persist to SQL-backed store without integrity error.""" + repo = await _make_repo(tmp_path) + manager = RunManager(store=repo) + record = await manager.create("thread-1") + + await manager.update_model_name(record.run_id, "gpt-4o") + + row = await repo.get(record.run_id) + assert row is not None + assert row["model_name"] == "gpt-4o" + await _cleanup() + + @pytest.mark.anyio + async def test_run_manager_update_model_name_twice(self, tmp_path): + """RunManager.update_model_name should support multiple updates.""" + repo = await _make_repo(tmp_path) + manager = RunManager(store=repo) + record = await manager.create("thread-1") + + await manager.update_model_name(record.run_id, "model-1") + await manager.update_model_name(record.run_id, "model-2") + + row = await repo.get(record.run_id) + assert row["model_name"] == "model-2" + await _cleanup() diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py index 0a4421e2f..72e3ac98e 100644 --- a/backend/tests/test_run_worker_rollback.py +++ b/backend/tests/test_run_worker_rollback.py @@ -88,7 +88,9 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory(): assert captured["factory_context"]["app_config"] is app_config assert captured["astream_context"]["app_config"] is app_config - assert run_manager.get(record.run_id).status == RunStatus.success + fetched = await run_manager.get(record.run_id) + assert fetched is not None + assert fetched.status == RunStatus.success bridge.publish_end.assert_awaited_once_with(record.run_id) bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60) diff --git a/backend/tests/test_thread_run_messages_pagination.py b/backend/tests/test_thread_run_messages_pagination.py index 00e354a34..9098e2b73 100644 --- a/backend/tests/test_thread_run_messages_pagination.py +++ b/backend/tests/test_thread_run_messages_pagination.py @@ -2,25 +2,30 @@ from __future__ import annotations +import asyncio from unittest.mock import AsyncMock, MagicMock from _router_auth_helpers import make_authed_test_app from fastapi.testclient import TestClient from app.gateway.routers import thread_runs +from deerflow.runtime import RunManager +from deerflow.runtime.runs.store.memory import MemoryRunStore # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _make_app(event_store=None): +def _make_app(event_store=None, run_manager=None): """Build a test FastAPI app with stub auth and mocked state.""" app = make_authed_test_app() app.include_router(thread_runs.router) if event_store is not None: app.state.run_event_store = event_store + if run_manager is not None: + app.state.run_manager = run_manager return app @@ -36,6 +41,23 @@ def _make_message(seq: int) -> dict: return {"seq": seq, "event_type": "ai_message", "category": "message", "content": f"msg-{seq}"} +def _make_store_only_run_manager() -> RunManager: + store = MemoryRunStore() + asyncio.run( + store.put( + "store-only-run", + thread_id="thread-store", + assistant_id="lead_agent", + status="running", + multitask_strategy="reject", + metadata={}, + kwargs={}, + created_at="2026-01-01T00:00:00+00:00", + ) + ) + return RunManager(store=store) + + # --------------------------------------------------------------------------- # Tests # --------------------------------------------------------------------------- @@ -128,3 +150,46 @@ def test_empty_data_when_no_messages(): body = response.json() assert body["data"] == [] assert body["has_more"] is False + + +def test_get_run_hydrates_store_only_run(): + """GET /api/threads/{tid}/runs/{rid} should read historical store rows.""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.get("/api/threads/thread-store/runs/store-only-run") + + assert response.status_code == 200 + body = response.json() + assert body["run_id"] == "store-only-run" + assert body["thread_id"] == "thread-store" + assert body["status"] == "running" + + +def test_cancel_store_only_run_returns_409(): + """Store-only runs are readable but not cancellable by this worker.""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.post("/api/threads/thread-store/runs/store-only-run/cancel") + + assert response.status_code == 409 + assert "not active on this worker" in response.json()["detail"] + + +def test_join_store_only_run_returns_409(): + """join endpoint should return 409 for store-only runs (no local stream state).""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.get("/api/threads/thread-store/runs/store-only-run/join") + + assert response.status_code == 409 + assert "not active on this worker" in response.json()["detail"] + + +def test_stream_store_only_run_returns_409(): + """stream endpoint (action=None) should return 409 for store-only runs.""" + app = _make_app(run_manager=_make_store_only_run_manager()) + with TestClient(app) as client: + response = client.get("/api/threads/thread-store/runs/store-only-run/stream") + + assert response.status_code == 409 + assert "not active on this worker" in response.json()["detail"] From 3599b570a92a34fa27f4ca4979f0b0a04eb6ad1c Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Tue, 19 May 2026 22:11:46 +0800 Subject: [PATCH 41/86] fix(harness): wrap all async-only tools for sync clients (#2935) --- .../packages/harness/deerflow/tools/sync.py | 66 ++++++++++++-- .../packages/harness/deerflow/tools/tools.py | 2 +- backend/tests/test_invoke_acp_agent_tool.py | 86 +++++++++++++++++++ backend/tests/test_mcp_sync_wrapper.py | 54 ++++++++++++ backend/tests/test_tool_deduplication.py | 58 +++++++++++++ 5 files changed, 260 insertions(+), 6 deletions(-) diff --git a/backend/packages/harness/deerflow/tools/sync.py b/backend/packages/harness/deerflow/tools/sync.py index c2b80781a..7521dd7b3 100644 --- a/backend/packages/harness/deerflow/tools/sync.py +++ b/backend/packages/harness/deerflow/tools/sync.py @@ -3,9 +3,13 @@ import asyncio import atexit import concurrent.futures +import contextvars +import functools import logging from collections.abc import Callable -from typing import Any +from typing import Any, get_type_hints + +from langchain_core.runnables import RunnableConfig logger = logging.getLogger(__name__) @@ -15,10 +19,49 @@ _SYNC_TOOL_EXECUTOR = concurrent.futures.ThreadPoolExecutor(max_workers=10, thre atexit.register(lambda: _SYNC_TOOL_EXECUTOR.shutdown(wait=False)) -def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: - """Build a synchronous wrapper for an asynchronous tool coroutine.""" +def _get_runnable_config_param(func: Callable[..., Any]) -> str | None: + """Return the coroutine parameter that expects LangChain RunnableConfig.""" + if isinstance(func, functools.partial): + func = func.func - def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + try: + type_hints = get_type_hints(func) + except Exception: + return None + + for name, type_ in type_hints.items(): + if type_ is RunnableConfig: + return name + return None + + +def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable[..., Any]: + """Build a synchronous wrapper for an asynchronous tool coroutine. + + Args: + coro: Async callable backing a LangChain tool. + tool_name: Tool name used in error logs. + + Returns: + A sync callable suitable for ``BaseTool.func``. + + Notes: + If ``coro`` declares a ``RunnableConfig`` parameter, this wrapper + exposes ``config: RunnableConfig`` so LangChain can inject runtime + config and then forwards it to the coroutine's detected config + parameter. This covers DeerFlow's current config-sensitive tools, such + as ``invoke_acp_agent``. + + This wrapper intentionally does not synthesize a dynamic function + signature. A future async tool with a normal user-facing argument named + ``config`` and a separate ``RunnableConfig`` parameter named something + else, such as ``run_config``, may collide with LangChain's injected + ``config`` argument. Rename that user-facing field or extend this + helper before using that signature. + """ + config_param = _get_runnable_config_param(coro) + + def run_coroutine(*args: Any, **kwargs: Any) -> Any: try: loop = asyncio.get_running_loop() except RuntimeError: @@ -26,11 +69,24 @@ def make_sync_tool_wrapper(coro: Callable[..., Any], tool_name: str) -> Callable try: if loop is not None and loop.is_running(): - future = _SYNC_TOOL_EXECUTOR.submit(asyncio.run, coro(*args, **kwargs)) + context = contextvars.copy_context() + future = _SYNC_TOOL_EXECUTOR.submit(context.run, lambda: asyncio.run(coro(*args, **kwargs))) return future.result() return asyncio.run(coro(*args, **kwargs)) except Exception as e: logger.error("Error invoking tool %r via sync wrapper: %s", tool_name, e, exc_info=True) raise + if config_param: + + def sync_wrapper(*args: Any, config: RunnableConfig = None, **kwargs: Any) -> Any: + if config is not None or config_param not in kwargs: + kwargs[config_param] = config + return run_coroutine(*args, **kwargs) + + return sync_wrapper + + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + return run_coroutine(*args, **kwargs) + return sync_wrapper diff --git a/backend/packages/harness/deerflow/tools/tools.py b/backend/packages/harness/deerflow/tools/tools.py index 5c97962fc..bc2caed43 100644 --- a/backend/packages/harness/deerflow/tools/tools.py +++ b/backend/packages/harness/deerflow/tools/tools.py @@ -205,7 +205,7 @@ def get_available_tools( # Deduplicate by tool name — config-loaded tools take priority, followed by # built-ins, MCP tools, and ACP tools. Duplicate names cause the LLM to # receive ambiguous or concatenated function schemas (issue #1803). - all_tools = loaded_tools + builtin_tools + mcp_tools + acp_tools + all_tools = [_ensure_sync_invocable_tool(t) for t in loaded_tools + builtin_tools + mcp_tools + acp_tools] seen_names: set[str] = set() unique_tools: list[BaseTool] = [] for t in all_tools: diff --git a/backend/tests/test_invoke_acp_agent_tool.py b/backend/tests/test_invoke_acp_agent_tool.py index 8c44403b8..deace5b4e 100644 --- a/backend/tests/test_invoke_acp_agent_tool.py +++ b/backend/tests/test_invoke_acp_agent_tool.py @@ -699,6 +699,92 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo load_acp_config_from_dict({}) +def test_get_available_tools_sync_invoke_acp_agent_preserves_thread_workspace(monkeypatch, tmp_path): + from deerflow.config import paths as paths_module + from deerflow.runtime import user_context as uc_module + + monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path)) + monkeypatch.setattr(uc_module, "get_effective_user_id", lambda: None) + monkeypatch.setattr( + "deerflow.config.extensions_config.ExtensionsConfig.from_file", + classmethod(lambda cls: ExtensionsConfig(mcp_servers={}, skills={})), + ) + monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True) + + captured: dict[str, object] = {} + + class DummyClient: + @property + def collected_text(self) -> str: + return "ok" + + async def session_update(self, session_id, update, **kwargs): + pass + + async def request_permission(self, options, session_id, tool_call, **kwargs): + raise AssertionError("should not be called") + + class DummyConn: + async def initialize(self, **kwargs): + pass + + async def new_session(self, **kwargs): + return SimpleNamespace(session_id="s1") + + async def prompt(self, **kwargs): + pass + + class DummyProcessContext: + def __init__(self, client, cmd, *args, env=None, cwd): + captured["cwd"] = cwd + + async def __aenter__(self): + return DummyConn(), object() + + async def __aexit__(self, exc_type, exc, tb): + return False + + monkeypatch.setitem( + sys.modules, + "acp", + SimpleNamespace( + PROTOCOL_VERSION="2026-03-24", + Client=DummyClient, + spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd), + text_block=lambda text: {"type": "text", "text": text}, + ), + ) + monkeypatch.setitem( + sys.modules, + "acp.schema", + SimpleNamespace( + ClientCapabilities=lambda: {}, + Implementation=lambda **kwargs: kwargs, + TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}), + ), + ) + + explicit_config = SimpleNamespace( + tools=[], + models=[], + tool_search=SimpleNamespace(enabled=False), + skill_evolution=SimpleNamespace(enabled=False), + sandbox=SimpleNamespace(), + get_model_config=lambda name: None, + acp_agents={"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}, + ) + tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config) + tool = next(tool for tool in tools if tool.name == "invoke_acp_agent") + + thread_id = "thread-sync-123" + tool.invoke( + {"agent": "codex", "prompt": "Do something"}, + config={"configurable": {"thread_id": thread_id}}, + ) + + assert captured["cwd"] == str(tmp_path / "threads" / thread_id / "acp-workspace") + + def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch): explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")} explicit_config = SimpleNamespace( diff --git a/backend/tests/test_mcp_sync_wrapper.py b/backend/tests/test_mcp_sync_wrapper.py index 285200781..c66662bb5 100644 --- a/backend/tests/test_mcp_sync_wrapper.py +++ b/backend/tests/test_mcp_sync_wrapper.py @@ -1,7 +1,9 @@ import asyncio +import contextvars from unittest.mock import AsyncMock, MagicMock, patch import pytest +from langchain_core.runnables import RunnableConfig from langchain_core.tools import StructuredTool from pydantic import BaseModel, Field @@ -69,6 +71,58 @@ def test_mcp_tool_sync_wrapper_in_running_loop(): assert result == "async_result: 100" +def test_sync_wrapper_preserves_contextvars_in_running_loop(): + """The executor branch preserves LangGraph-style contextvars.""" + current_value: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_value", default=None) + + async def mock_coro() -> str | None: + return current_value.get() + + sync_func = make_sync_tool_wrapper(mock_coro, "test_tool") + + async def run_in_loop() -> str | None: + token = current_value.set("from-parent-context") + try: + return sync_func() + finally: + current_value.reset(token) + + assert asyncio.run(run_in_loop()) == "from-parent-context" + + +def test_sync_wrapper_preserves_runnable_config_injection(): + """LangChain can still inject RunnableConfig after an async tool is wrapped.""" + captured: dict[str, object] = {} + + async def mock_coro(x: int, config: RunnableConfig = None): + captured["thread_id"] = ((config or {}).get("configurable") or {}).get("thread_id") + return f"result: {x}" + + mock_tool = StructuredTool( + name="test_tool", + description="test description", + args_schema=MockArgs, + func=make_sync_tool_wrapper(mock_coro, "test_tool"), + coroutine=mock_coro, + ) + + result = mock_tool.invoke({"x": 42}, config={"configurable": {"thread_id": "thread-123"}}) + + assert result == "result: 42" + assert captured["thread_id"] == "thread-123" + + +def test_sync_wrapper_preserves_regular_config_argument(): + """Only RunnableConfig-annotated coroutine params get special config injection.""" + + async def mock_coro(config: str): + return config + + sync_func = make_sync_tool_wrapper(mock_coro, "test_tool") + + assert sync_func(config="user-config") == "user-config" + + def test_mcp_tool_sync_wrapper_exception_logging(): """Test the shared sync wrapper's error logging.""" diff --git a/backend/tests/test_tool_deduplication.py b/backend/tests/test_tool_deduplication.py index f018fc57d..b8a7a3127 100644 --- a/backend/tests/test_tool_deduplication.py +++ b/backend/tests/test_tool_deduplication.py @@ -95,6 +95,64 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): assert async_tool.invoke({"x": 42}) == "result: 42" +@patch("deerflow.tools.tools.get_app_config") +@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) +def test_subagent_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): + """Async-only tools added through the subagent path can be invoked by sync clients.""" + + async def async_tool_impl(x: int) -> str: + return f"subagent: {x}" + + async_tool = StructuredTool( + name="async_subagent_tool", + description="Async-only subagent test tool.", + args_schema=AsyncToolArgs, + func=None, + coroutine=async_tool_impl, + ) + mock_cfg.return_value = _make_minimal_config([]) + + with ( + patch("deerflow.tools.tools.BUILTIN_TOOLS", []), + patch("deerflow.tools.tools.SUBAGENT_TOOLS", [async_tool]), + ): + result = get_available_tools(include_mcp=False, subagent_enabled=True, app_config=mock_cfg.return_value) + + assert async_tool in result + assert async_tool.func is not None + assert async_tool.invoke({"x": 7}) == "subagent: 7" + + +@patch("deerflow.tools.tools.get_app_config") +@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) +def test_acp_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg): + """Async-only ACP tools can be invoked by sync clients.""" + + async def async_tool_impl(x: int) -> str: + return f"acp: {x}" + + async_tool = StructuredTool( + name="invoke_acp_agent", + description="Async-only ACP test tool.", + args_schema=AsyncToolArgs, + func=None, + coroutine=async_tool_impl, + ) + config = _make_minimal_config([]) + config.acp_agents = {"codex": object()} + mock_cfg.return_value = config + + with ( + patch("deerflow.tools.tools.BUILTIN_TOOLS", []), + patch("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", return_value=async_tool), + ): + result = get_available_tools(include_mcp=False, app_config=config) + + assert async_tool in result + assert async_tool.func is not None + assert async_tool.invoke({"x": 9}) == "acp: 9" + + @patch("deerflow.tools.tools.get_app_config") @patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True) def test_no_duplicates_returned(mock_bash, mock_cfg): From b69ca7ad9768ee26225fb134bd90706402440e2e Mon Sep 17 00:00:00 2001 From: Lawrance_YXLiao <32213920+kibabsquirrel@users.noreply.github.com> Date: Tue, 19 May 2026 22:34:51 +0800 Subject: [PATCH 42/86] test(middleware): lock tool-call transcript boundary invariants (#3049) --- .../test_dangling_tool_call_middleware.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index f9f47369d..5ecded924 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -190,6 +190,24 @@ class TestBuildPatchedMessagesPatching: assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"] assert isinstance(patched[3], HumanMessage) + def test_non_tool_message_inserted_between_partial_tool_results_is_regrouped(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]), + _tool_msg("call_1", "bash"), + HumanMessage(content="interruption"), + _tool_msg("call_2", "read"), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[0], AIMessage) + assert isinstance(patched[1], ToolMessage) + assert isinstance(patched[2], ToolMessage) + assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"] + assert isinstance(patched[3], HumanMessage) + def test_valid_adjacent_tool_results_are_unchanged(self): mw = DanglingToolCallMiddleware() msgs = [ @@ -237,7 +255,8 @@ class TestBuildPatchedMessagesPatching: assert isinstance(patched[0], AIMessage) assert isinstance(patched[1], ToolMessage) assert patched[1].tool_call_id == "call_1" - assert orphan in patched + assert patched[2] is orphan + assert isinstance(patched[3], HumanMessage) assert patched.count(orphan) == 1 def test_invalid_tool_call_is_patched(self): From b1ec7e81114b424e2a3b3c21ea13e50d0b8a5a2c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 May 2026 07:27:45 +0800 Subject: [PATCH 43/86] chore(deps): bump idna from 3.13 to 3.15 in /backend (#3077) Bumps [idna](https://github.com/kjd/idna) from 3.13 to 3.15. - [Release notes](https://github.com/kjd/idna/releases) - [Changelog](https://github.com/kjd/idna/blob/master/HISTORY.md) - [Commits](https://github.com/kjd/idna/compare/v3.13...v3.15) --- updated-dependencies: - dependency-name: idna dependency-version: '3.15' dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- backend/uv.lock | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/uv.lock b/backend/uv.lock index 9cc2030fa..5501fb81f 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'win32'", @@ -1504,11 +1504,11 @@ wheels = [ [[package]] name = "idna" -version = "3.13" +version = "3.15" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ce/cc/762dfb036166873f0059f3b7de4565e1b5bc3d6f28a414c13da27e442f99/idna-3.13.tar.gz", hash = "sha256:585ea8fe5d69b9181ec1afba340451fba6ba764af97026f92a91d4eef164a242", size = 194210, upload-time = "2026-04-22T16:42:42.314Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/77/7b3966d0b9d1d31a36ddf1746926a11dface89a83409bf1483f0237aa758/idna-3.15.tar.gz", hash = "sha256:ca962446ea538f7092a95e057da437618e886f4d349216d2b1e294abfdb65fdc", size = 199245, upload-time = "2026-05-12T22:45:57.011Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/13/ad7d7ca3808a898b4612b6fe93cde56b53f3034dcde235acb1f0e1df24c6/idna-3.13-py3-none-any.whl", hash = "sha256:892ea0cde124a99ce773decba204c5552b69c3c67ffd5f232eb7696135bc8bb3", size = 68629, upload-time = "2026-04-22T16:42:40.909Z" }, + { url = "https://files.pythonhosted.org/packages/d2/23/408243171aa9aaba178d3e2559159c24c1171a641aa83b67bdd3394ead8e/idna-3.15-py3-none-any.whl", hash = "sha256:048adeaf8c2d788c40fee287673ccaa74c24ffd8dcf09ffa555a2fbb59f10ac8", size = 72340, upload-time = "2026-05-12T22:45:55.733Z" }, ] [[package]] From 006948232ccb31179a74bd7433f66a410fa2c360 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 20 May 2026 07:42:03 +0800 Subject: [PATCH 44/86] chore(deps): bump brace-expansion from 1.1.12 to 5.0.5 in /frontend (#3078) Bumps [brace-expansion](https://github.com/juliangruber/brace-expansion) from 1.1.12 to 5.0.5. - [Release notes](https://github.com/juliangruber/brace-expansion/releases) - [Commits](https://github.com/juliangruber/brace-expansion/compare/v1.1.12...v5.0.5) --- updated-dependencies: - dependency-name: brace-expansion dependency-version: 5.0.5 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- frontend/pnpm-lock.yaml | 226 ++++++++++++++++++++-------------------- 1 file changed, 113 insertions(+), 113 deletions(-) diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 8c80061c9..426b607e8 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -1731,128 +1731,128 @@ packages: resolution: {integrity: sha512-FqALmHI8D4o6lk/LRWDnhw95z5eO+eAa6ORjVg09YRR7BkcM6oPHU9uyC0gtQG5vpFLvgpeU4+zEAz2H8APHNw==} engines: {node: '>= 10'} - '@rollup/rollup-android-arm-eabi@4.60.3': - resolution: {integrity: sha512-x35CNW/ANXG3hE/EZpRU8MXX1JDN86hBb2wMGAtltkz7pc6cxgjpy1OMMfDosOQ+2hWqIkag/fGok1Yady9nGw==} + '@rollup/rollup-android-arm-eabi@4.60.4': + resolution: {integrity: sha512-F5QXMSiFebS9hKZj02XhWLLnRpJ3B3AROP0tWbFBSj+6kCbg5m9j5JoHKd4mmSVy5mS/IMQloYgYxCuJC0fxEQ==} cpu: [arm] os: [android] - '@rollup/rollup-android-arm64@4.60.3': - resolution: {integrity: sha512-xw3xtkDApIOGayehp2+Rz4zimfkaX65r4t47iy+ymQB2G4iJCBBfj0ogVg5jpvjpn8UWn/+q9tprxleYeNp3Hw==} + '@rollup/rollup-android-arm64@4.60.4': + resolution: {integrity: sha512-GxxTKApUpzRhof7poWvCJHRF51C67u1R7D6DiluBE8wKU1u5GWE8t+v81JvJYtbawoBFX1hLv5Ei4eVjkWokaw==} cpu: [arm64] os: [android] - '@rollup/rollup-darwin-arm64@4.60.3': - resolution: {integrity: sha512-vo6Y5Qfpx7/5EaamIwi0WqW2+zfiusVihKatLvtN1VFVy3D13uERk/6gZLU1UiHRL6fDXqj/ELIeVRGnvcTE1g==} + '@rollup/rollup-darwin-arm64@4.60.4': + resolution: {integrity: sha512-tua0TaJxMOB1R0V0RS1jFZ/RpURFDJIOR2A6jWwQeawuFyS4gBW+rntLRaQd0EQ4bd6Vp44Z2rXW+YYDBsj6IA==} cpu: [arm64] os: [darwin] - '@rollup/rollup-darwin-x64@4.60.3': - resolution: {integrity: sha512-D+0QGcZhBzTN82weOnsSlY7V7+RMmPuF1CkbxyMAGE8+ZHeUjyb76ZiWmBlCu//AQQONvxcqRbwZTajZKqjuOw==} + '@rollup/rollup-darwin-x64@4.60.4': + resolution: {integrity: sha512-CSKq7MsP+5PFIcydhAiR1K0UhEI1A2jWXVKHPCBZ151yOutENwvnPocgVHkivu2kviURtCEB6zUQw0vs8RrhMg==} cpu: [x64] os: [darwin] - '@rollup/rollup-freebsd-arm64@4.60.3': - resolution: {integrity: sha512-6HnvHCT7fDyj6R0Ph7A6x8dQS/S38MClRWeDLqc0MdfWkxjiu1HSDYrdPhqSILzjTIC/pnXbbJbo+ft+gy/9hQ==} + '@rollup/rollup-freebsd-arm64@4.60.4': + resolution: {integrity: sha512-+O8OkVdyvXMtJEciu2wS/pzm1IxntEEQx3z5TAVy4l32G0etZn+RsA48ARRrFm6Ri8fvqPQfgrvNxSjKAbnd3g==} cpu: [arm64] os: [freebsd] - '@rollup/rollup-freebsd-x64@4.60.3': - resolution: {integrity: sha512-KHLgC3WKlUYW3ShFKnnosZDOJ0xjg9zp7au3sIm2bs/tGBeC2ipmvRh/N7JKi0t9Ue20C0dpEshi8WUubg+cnA==} + '@rollup/rollup-freebsd-x64@4.60.4': + resolution: {integrity: sha512-Iw3oMskH3AfNuhU0MSN7vNbdi4me/NiYo2azqPz/Le16zHSa+3RRmliCMWWQmh4lcndccU40xcJuTYJZxNo/lw==} cpu: [x64] os: [freebsd] - '@rollup/rollup-linux-arm-gnueabihf@4.60.3': - resolution: {integrity: sha512-DV6fJoxEYWJOvaZIsok7KrYl0tPvga5OZ2yvKHNNYyk/2roMLqQAbGhr78EQ5YhHpnhLKJD3S1WFusAkmUuV5g==} + '@rollup/rollup-linux-arm-gnueabihf@4.60.4': + resolution: {integrity: sha512-EIPRXTVQpHyF8WOo219AD2yEltPehLTcTMz2fn6JsatLYSzQf00hj3rulF+yauOlF9/FtM2WpkT/hJh/KJFGhA==} cpu: [arm] os: [linux] - '@rollup/rollup-linux-arm-musleabihf@4.60.3': - resolution: {integrity: sha512-mQKoJAzvuOs6F+TZybQO4GOTSMUu7v0WdxEk24krQ/uUxXoPTtHjuaUuPmFhtBcM4K0ons8nrE3JyhTuCFtT/w==} + '@rollup/rollup-linux-arm-musleabihf@4.60.4': + resolution: {integrity: sha512-J3Yh9PzzF1Ovah2At+lHiGQdsYgArxBbXv/zHfSyaiFQEqvNv7DcW98pCrmdjCZBrqBiKrKKe2V+aaSGWuBe/w==} cpu: [arm] os: [linux] - '@rollup/rollup-linux-arm64-gnu@4.60.3': - resolution: {integrity: sha512-Whjj2qoiJ6+OOJMGptTYazaJvjOJm+iKHpXQM1P3LzGjt7Ff++Tp7nH4N8J/BUA7R9IHfDyx4DJIflifwnbmIA==} + '@rollup/rollup-linux-arm64-gnu@4.60.4': + resolution: {integrity: sha512-BFDEZMYfUvLn37ONE1yMBojPxnMlTFsdyNoqncT0qFq1mAfllL+ATMMJd8TeuVMiX84s1KbcxcZbXInmcO2mRg==} cpu: [arm64] os: [linux] - '@rollup/rollup-linux-arm64-musl@4.60.3': - resolution: {integrity: sha512-4YTNHKqGng5+yiZt3mg77nmyuCfmNfX4fPmyUapBcIk+BdwSwmCWGXOUxhXbBEkFHtoN5boLj/5NON+u5QC9tg==} + '@rollup/rollup-linux-arm64-musl@4.60.4': + resolution: {integrity: sha512-pc9EYOSlOgdQ2uPl1o9PF6/kLSgaUosia7gOuS8mB69IxJvlclko1MECXysjs5ryez1/5zjYqx3+xYU0TU6R1A==} cpu: [arm64] os: [linux] - '@rollup/rollup-linux-loong64-gnu@4.60.3': - resolution: {integrity: sha512-SU3kNlhkpI4UqlUc2VXPGK9o886ZsSeGfMAX2ba2b8DKmMXq4AL7KUrkSWVbb7koVqx41Yczx6dx5PNargIrEA==} + '@rollup/rollup-linux-loong64-gnu@4.60.4': + resolution: {integrity: sha512-NxnomyxYerDh5n4iLrNa+sH+Z+U4BMEE46V2PgQ/hoB909i8gV1M5wPojWg9fk1jWpO3IQnOs20K4wyZuFLEFQ==} cpu: [loong64] os: [linux] - '@rollup/rollup-linux-loong64-musl@4.60.3': - resolution: {integrity: sha512-6lDLl5h4TXpB1mTf2rQWnAk/LcXrx9vBfu/DT5TIPhvMhRWaZ5MxkIc8u4lJAmBo6klTe1ywXIUHFjylW505sg==} + '@rollup/rollup-linux-loong64-musl@4.60.4': + resolution: {integrity: sha512-nbJnQ8a3z1mtmrwImCYhc6BGpThAyYVRQxw9uKSKG4wR6aAYno9sVjJ0zaZcW9BPJX1GbrDPf+SvdWjgTuDmnw==} cpu: [loong64] os: [linux] - '@rollup/rollup-linux-ppc64-gnu@4.60.3': - resolution: {integrity: sha512-BMo8bOw8evlup/8G+cj5xWtPyp93xPdyoSN16Zy90Q2QZ0ZYRhCt6ZJSwbrRzG9HApFabjwj2p25TUPDWrhzqQ==} + '@rollup/rollup-linux-ppc64-gnu@4.60.4': + resolution: {integrity: sha512-2EU6acNrQLd8tYvo/LXW535wupT3m6fo7HKo6lr7ktQoItxTyOL1ZCR/GfGCuXl2vR+zmfI6eRXkSemafv+iVg==} cpu: [ppc64] os: [linux] - '@rollup/rollup-linux-ppc64-musl@4.60.3': - resolution: {integrity: sha512-E0L8X1dZN1/Rph+5VPF6Xj2G7JJvMACVXtamTJIDrVI44Y3K+G8gQaMEAavbqCGTa16InptiVrX6eM6pmJ+7qA==} + '@rollup/rollup-linux-ppc64-musl@4.60.4': + resolution: {integrity: sha512-WeBtoMuaMxiiIrO2IYP3xs6GMWkJP2C0EoT8beTLkUPmzV1i/UcOSVw1d5r9KBODtHKilG5yFxsGRnBbK3wJ4A==} cpu: [ppc64] os: [linux] - '@rollup/rollup-linux-riscv64-gnu@4.60.3': - resolution: {integrity: sha512-oZJ/WHaVfHUiRAtmTAeo3DcevNsVvH8mbvodjZy7D5QKvCefO371SiKRpxoDcCxB3PTRTLayWBkvmDQKTcX/sw==} + '@rollup/rollup-linux-riscv64-gnu@4.60.4': + resolution: {integrity: sha512-FJHFfqpKUI3A10WrWKiFbBZ7yVbGT4q4B5o1qKFFojqpaYoh9LrQgqWCmmcxQzVSXYtyB5bzkXrYzlHTs21MYA==} cpu: [riscv64] os: [linux] - '@rollup/rollup-linux-riscv64-musl@4.60.3': - resolution: {integrity: sha512-Dhbyh7j9FybM3YaTgaHmVALwA8AkUwTPccyCQ79TG9AJUsMQqgN1DDEZNr4+QUfwiWvLDumW5vdwzoeUF+TNxQ==} + '@rollup/rollup-linux-riscv64-musl@4.60.4': + resolution: {integrity: sha512-mcEl6CUT5IAUmQf1m9FYSmVqCJlpQ8r8eyftFUHG8i9OhY7BkBXSUdnLH5DOf0wCOjcP9v/QO93zpmF1SptCCw==} cpu: [riscv64] os: [linux] - '@rollup/rollup-linux-s390x-gnu@4.60.3': - resolution: {integrity: sha512-cJd1X5XhHHlltkaypz1UcWLA8AcoIi1aWhsvaWDskD1oz2eKCypnqvTQ8ykMNI0RSmm7NkTdSqSSD7zM0xa6Ig==} + '@rollup/rollup-linux-s390x-gnu@4.60.4': + resolution: {integrity: sha512-ynt3JxVd2w2buzoKDWIyiV1pJW93xlQic1THVLXilz429oijRpSHivZAgp65KBu+cMcgf1eVVjdnTLvPxgCuoQ==} cpu: [s390x] os: [linux] - '@rollup/rollup-linux-x64-gnu@4.60.3': - resolution: {integrity: sha512-DAZDBHQfG2oQuhY7mc6I3/qB4LU2fQCjRvxbDwd/Jdvb9fypP4IJ4qmtu6lNjes6B531AI8cg1aKC2di97bUxA==} + '@rollup/rollup-linux-x64-gnu@4.60.4': + resolution: {integrity: sha512-Boiz5+MsaROEWDf+GGEwF8VMHGhlUoQMtIPjOgA5fv4osupqTVnJteQNKJwUcnUog2G55jYXH7KZFFiJe0TEzQ==} cpu: [x64] os: [linux] - '@rollup/rollup-linux-x64-musl@4.60.3': - resolution: {integrity: sha512-cRxsE8c13mZOh3vP+wLDxpQBRrOHDIGOWyDL93Sy0Ga8y515fBcC2pjUfFwUe5T7tqvTvWbCpg1URM/AXdWIXA==} + '@rollup/rollup-linux-x64-musl@4.60.4': + resolution: {integrity: sha512-+qfSY27qIrFfI/Hom04KYFw3GKZSGU4lXus51wsb5EuySfFlWRwjkKWoE9emgRw/ukoT4Udsj4W/+xxG8VbPKg==} cpu: [x64] os: [linux] - '@rollup/rollup-openbsd-x64@4.60.3': - resolution: {integrity: sha512-QaWcIgRxqEdQdhJqW4DJctsH6HCmo5vHxY0krHSX4jMtOqfzC+dqDGuHM87bu4H8JBeibWx7jFz+h6/4C8wA5Q==} + '@rollup/rollup-openbsd-x64@4.60.4': + resolution: {integrity: sha512-VpTfOPHgVXEBeeR8hZ2O0F3aSso+JDWqTWmTmzcQKted54IAdUVbxE+j/MVxUsKa8L20HJhv3vUezVPoquqWjA==} cpu: [x64] os: [openbsd] - '@rollup/rollup-openharmony-arm64@4.60.3': - resolution: {integrity: sha512-AaXwSvUi3QIPtroAUw1t5yHGIyqKEXwH54WUocFolZhpGDruJcs8c+xPNDRn4XiQsS7MEwnYsHW2l0MBLDMkWg==} + '@rollup/rollup-openharmony-arm64@4.60.4': + resolution: {integrity: sha512-IPOsh5aRYuLv/nkU51X10Bf75Bsf6+gZdx1X+QP5QM6lIJFHHqbHLG0uJn/hWthzo13UAc2umiUorqZy3axoZg==} cpu: [arm64] os: [openharmony] - '@rollup/rollup-win32-arm64-msvc@4.60.3': - resolution: {integrity: sha512-65LAKM/bAWDqKNEelHlcHvm2V+Vfb8C6INFxQXRHCvaVN1rJfwr4NvdP4FyzUaLqWfaCGaadf6UbTm8xJeYfEg==} + '@rollup/rollup-win32-arm64-msvc@4.60.4': + resolution: {integrity: sha512-4QzE9E81OohJ/HKzHhsqU+zcYYojVOXlFMs1DdyMT6qXl/niOH7AVElmmEdUNHHS/oRkc++d5k6Vy85zFs0DEw==} cpu: [arm64] os: [win32] - '@rollup/rollup-win32-ia32-msvc@4.60.3': - resolution: {integrity: sha512-EEM2gyhBF5MFnI6vMKdX1LAosE627RGBzIoGMdLloPZkXrUN0Ckqgr2Qi8+J3zip/8NVVro3/FjB+tjhZUgUHA==} + '@rollup/rollup-win32-ia32-msvc@4.60.4': + resolution: {integrity: sha512-zTPgT1YuHHcd+Tmx7h8aml0FWFVelV5N54oHow9SLj+GfoDy/huQ+UV396N/C7KpMDMiPspRktzM1/0r1usYEA==} cpu: [ia32] os: [win32] - '@rollup/rollup-win32-x64-gnu@4.60.3': - resolution: {integrity: sha512-E5Eb5H/DpxaoXH++Qkv28RcUJboMopmdDUALBczvHMf7hNIxaDZqwY5lK12UK1BHacSmvupoEWGu+n993Z0y1A==} + '@rollup/rollup-win32-x64-gnu@4.60.4': + resolution: {integrity: sha512-DRS4G7mi9lJxqEDezIkKCaUIKCrLUUDCUaCsTPCi/rtqaC6D/jjwslMQyiDU50Ka0JKpeXeRBFBAXwArY52vBw==} cpu: [x64] os: [win32] - '@rollup/rollup-win32-x64-msvc@4.60.3': - resolution: {integrity: sha512-hPt/bgL5cE+Qp+/TPHBqptcAgPzgj46mPcg/16zNUmbQk0j+mOEQV/+Lqu8QRtDV3Ek95Q6FeFITpuhl6OTsAA==} + '@rollup/rollup-win32-x64-msvc@4.60.4': + resolution: {integrity: sha512-QVTUovf40zgTqlFVrKA1uXMVvU2QWEFWfAH8Wdc48IxLvrJMQVMBRjuQyUpzZCDkakImib9eVazbWlC6ksWtJw==} cpu: [x64] os: [win32] @@ -4079,8 +4079,8 @@ packages: resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} hasBin: true - lru-cache@11.3.6: - resolution: {integrity: sha512-Gf/KoL3C/MlI7Bt0PGI9I+TeTC/I6r/csU58N4BSNc4lppLBeKsOdFYkK+dX0ABDUMJNfCHTyPpzwwO21Awd3A==} + lru-cache@11.5.0: + resolution: {integrity: sha512-5YgH9UJd7wVb9hIouI2adWpgqrrICkt070Dnj8EUY1+B4B2P9eRLPAkAAo6NICA7CEhOIeBHl46u9zSNpNu7zA==} engines: {node: 20 || >=22} lucide-react@0.542.0: @@ -4671,8 +4671,8 @@ packages: resolution: {integrity: sha512-PS08Iboia9mts/2ygV3eLpY5ghnUcfLV/EXTOW1E2qYxJKGGBUtNjN76FYHnMs36RmARn41bC0AZmn+rR0OVpQ==} engines: {node: ^10 || ^12 || >=14} - postcss@8.5.14: - resolution: {integrity: sha512-SoSL4+OSEtR99LHFZQiJLkT59C5B1amGO1NzTwj7TT1qCUgUO6hxOvzkOYxD+vMrXBM3XJIKzokoERdqQq/Zmg==} + postcss@8.5.15: + resolution: {integrity: sha512-FfR8sjd4em2T6fb3I2MwAJU7HWVMr9zba+enmQeeWFfCbm+UOC/0X4DS8XtpUTMwWMGbjKYP7xjfNekzyGmB3A==} engines: {node: ^10 || ^12 || >=14} postcss@8.5.6: @@ -4962,8 +4962,8 @@ packages: robust-predicates@3.0.2: resolution: {integrity: sha512-IXgzBWvWQwE6PrDI05OvmXUIruQTcoMDzRsOd5CDvHCVLcLHMTSYvOK5Cm46kWqlV3yAbuSpBZdJ5oP5OUoStg==} - rollup@4.60.3: - resolution: {integrity: sha512-pAQK9HalE84QSm4Po3EmWIZPd3FnjkShVkiMlz1iligWYkWQ7wHYd1PF/T7QZ5TVSD6uSTon5gBVMSM4JfBV+A==} + rollup@4.60.4: + resolution: {integrity: sha512-WHeFSbZYsPu3+bLoNRUuAO+wavNlocOPf3wSHTP7hcFKVnJeWsYlCDbr3mTS14FCizf9ccIxXA8sGL8zKeQN3g==} engines: {node: '>=18.0.0', npm: '>=8.0.0'} hasBin: true @@ -7297,79 +7297,79 @@ snapshots: '@resvg/resvg-wasm@2.6.2': {} - '@rollup/rollup-android-arm-eabi@4.60.3': + '@rollup/rollup-android-arm-eabi@4.60.4': optional: true - '@rollup/rollup-android-arm64@4.60.3': + '@rollup/rollup-android-arm64@4.60.4': optional: true - '@rollup/rollup-darwin-arm64@4.60.3': + '@rollup/rollup-darwin-arm64@4.60.4': optional: true - '@rollup/rollup-darwin-x64@4.60.3': + '@rollup/rollup-darwin-x64@4.60.4': optional: true - '@rollup/rollup-freebsd-arm64@4.60.3': + '@rollup/rollup-freebsd-arm64@4.60.4': optional: true - '@rollup/rollup-freebsd-x64@4.60.3': + '@rollup/rollup-freebsd-x64@4.60.4': optional: true - '@rollup/rollup-linux-arm-gnueabihf@4.60.3': + '@rollup/rollup-linux-arm-gnueabihf@4.60.4': optional: true - '@rollup/rollup-linux-arm-musleabihf@4.60.3': + '@rollup/rollup-linux-arm-musleabihf@4.60.4': optional: true - '@rollup/rollup-linux-arm64-gnu@4.60.3': + '@rollup/rollup-linux-arm64-gnu@4.60.4': optional: true - '@rollup/rollup-linux-arm64-musl@4.60.3': + '@rollup/rollup-linux-arm64-musl@4.60.4': optional: true - '@rollup/rollup-linux-loong64-gnu@4.60.3': + '@rollup/rollup-linux-loong64-gnu@4.60.4': optional: true - '@rollup/rollup-linux-loong64-musl@4.60.3': + '@rollup/rollup-linux-loong64-musl@4.60.4': optional: true - '@rollup/rollup-linux-ppc64-gnu@4.60.3': + '@rollup/rollup-linux-ppc64-gnu@4.60.4': optional: true - '@rollup/rollup-linux-ppc64-musl@4.60.3': + '@rollup/rollup-linux-ppc64-musl@4.60.4': optional: true - '@rollup/rollup-linux-riscv64-gnu@4.60.3': + '@rollup/rollup-linux-riscv64-gnu@4.60.4': optional: true - '@rollup/rollup-linux-riscv64-musl@4.60.3': + '@rollup/rollup-linux-riscv64-musl@4.60.4': optional: true - '@rollup/rollup-linux-s390x-gnu@4.60.3': + '@rollup/rollup-linux-s390x-gnu@4.60.4': optional: true - '@rollup/rollup-linux-x64-gnu@4.60.3': + '@rollup/rollup-linux-x64-gnu@4.60.4': optional: true - '@rollup/rollup-linux-x64-musl@4.60.3': + '@rollup/rollup-linux-x64-musl@4.60.4': optional: true - '@rollup/rollup-openbsd-x64@4.60.3': + '@rollup/rollup-openbsd-x64@4.60.4': optional: true - '@rollup/rollup-openharmony-arm64@4.60.3': + '@rollup/rollup-openharmony-arm64@4.60.4': optional: true - '@rollup/rollup-win32-arm64-msvc@4.60.3': + '@rollup/rollup-win32-arm64-msvc@4.60.4': optional: true - '@rollup/rollup-win32-ia32-msvc@4.60.3': + '@rollup/rollup-win32-ia32-msvc@4.60.4': optional: true - '@rollup/rollup-win32-x64-gnu@4.60.3': + '@rollup/rollup-win32-x64-gnu@4.60.4': optional: true - '@rollup/rollup-win32-x64-msvc@4.60.3': + '@rollup/rollup-win32-x64-msvc@4.60.4': optional: true '@rtsao/scc@1.1.0': {} @@ -8067,7 +8067,7 @@ snapshots: '@vue/shared': 3.5.28 estree-walker: 2.0.2 magic-string: 0.30.21 - postcss: 8.5.14 + postcss: 8.5.15 source-map-js: 1.2.1 '@vue/compiler-ssr@3.5.28': @@ -9947,7 +9947,7 @@ snapshots: dependencies: js-tokens: 4.0.0 - lru-cache@11.3.6: {} + lru-cache@11.5.0: {} lucide-react@0.542.0(react@19.2.4): dependencies: @@ -10941,7 +10941,7 @@ snapshots: picocolors: 1.1.1 source-map-js: 1.2.1 - postcss@8.5.14: + postcss@8.5.15: dependencies: nanoid: 3.3.12 picocolors: 1.1.1 @@ -11282,35 +11282,35 @@ snapshots: robust-predicates@3.0.2: {} - rollup@4.60.3: + rollup@4.60.4: dependencies: '@types/estree': 1.0.8 optionalDependencies: - '@rollup/rollup-android-arm-eabi': 4.60.3 - '@rollup/rollup-android-arm64': 4.60.3 - '@rollup/rollup-darwin-arm64': 4.60.3 - '@rollup/rollup-darwin-x64': 4.60.3 - '@rollup/rollup-freebsd-arm64': 4.60.3 - '@rollup/rollup-freebsd-x64': 4.60.3 - '@rollup/rollup-linux-arm-gnueabihf': 4.60.3 - '@rollup/rollup-linux-arm-musleabihf': 4.60.3 - '@rollup/rollup-linux-arm64-gnu': 4.60.3 - '@rollup/rollup-linux-arm64-musl': 4.60.3 - '@rollup/rollup-linux-loong64-gnu': 4.60.3 - '@rollup/rollup-linux-loong64-musl': 4.60.3 - '@rollup/rollup-linux-ppc64-gnu': 4.60.3 - '@rollup/rollup-linux-ppc64-musl': 4.60.3 - '@rollup/rollup-linux-riscv64-gnu': 4.60.3 - '@rollup/rollup-linux-riscv64-musl': 4.60.3 - '@rollup/rollup-linux-s390x-gnu': 4.60.3 - '@rollup/rollup-linux-x64-gnu': 4.60.3 - '@rollup/rollup-linux-x64-musl': 4.60.3 - '@rollup/rollup-openbsd-x64': 4.60.3 - '@rollup/rollup-openharmony-arm64': 4.60.3 - '@rollup/rollup-win32-arm64-msvc': 4.60.3 - '@rollup/rollup-win32-ia32-msvc': 4.60.3 - '@rollup/rollup-win32-x64-gnu': 4.60.3 - '@rollup/rollup-win32-x64-msvc': 4.60.3 + '@rollup/rollup-android-arm-eabi': 4.60.4 + '@rollup/rollup-android-arm64': 4.60.4 + '@rollup/rollup-darwin-arm64': 4.60.4 + '@rollup/rollup-darwin-x64': 4.60.4 + '@rollup/rollup-freebsd-arm64': 4.60.4 + '@rollup/rollup-freebsd-x64': 4.60.4 + '@rollup/rollup-linux-arm-gnueabihf': 4.60.4 + '@rollup/rollup-linux-arm-musleabihf': 4.60.4 + '@rollup/rollup-linux-arm64-gnu': 4.60.4 + '@rollup/rollup-linux-arm64-musl': 4.60.4 + '@rollup/rollup-linux-loong64-gnu': 4.60.4 + '@rollup/rollup-linux-loong64-musl': 4.60.4 + '@rollup/rollup-linux-ppc64-gnu': 4.60.4 + '@rollup/rollup-linux-ppc64-musl': 4.60.4 + '@rollup/rollup-linux-riscv64-gnu': 4.60.4 + '@rollup/rollup-linux-riscv64-musl': 4.60.4 + '@rollup/rollup-linux-s390x-gnu': 4.60.4 + '@rollup/rollup-linux-x64-gnu': 4.60.4 + '@rollup/rollup-linux-x64-musl': 4.60.4 + '@rollup/rollup-openbsd-x64': 4.60.4 + '@rollup/rollup-openharmony-arm64': 4.60.4 + '@rollup/rollup-win32-arm64-msvc': 4.60.4 + '@rollup/rollup-win32-ia32-msvc': 4.60.4 + '@rollup/rollup-win32-x64-gnu': 4.60.4 + '@rollup/rollup-win32-x64-msvc': 4.60.4 fsevents: 2.3.3 roughjs@4.6.6: @@ -11908,7 +11908,7 @@ snapshots: chokidar: 5.0.0 destr: 2.0.5 h3: 1.15.11 - lru-cache: 11.3.6 + lru-cache: 11.5.0 node-fetch-native: 1.6.7 ofetch: 1.5.1 ufo: 1.6.4 @@ -11985,8 +11985,8 @@ snapshots: esbuild: 0.27.7 fdir: 6.5.0(picomatch@4.0.4) picomatch: 4.0.4 - postcss: 8.5.14 - rollup: 4.60.3 + postcss: 8.5.15 + rollup: 4.60.4 tinyglobby: 0.2.16 optionalDependencies: '@types/node': 20.19.33 From 0c22349029067d5efde3ca5624a2d2f2120a48bd Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Wed, 20 May 2026 10:00:17 +0800 Subject: [PATCH 45/86] chore(dev): add async/thread boundary detector (#2936) * chore(dev): add thread boundary detector * chore(dev): reduce thread boundary detector false positives --- Makefile | 6 +- .../support/detectors/thread_boundaries.py | 507 ++++++++++++++++++ .../tests/test_detect_thread_boundaries.py | 182 +++++++ scripts/detect_thread_boundaries.py | 23 + 4 files changed, 717 insertions(+), 1 deletion(-) create mode 100644 backend/tests/support/detectors/thread_boundaries.py create mode 100644 backend/tests/test_detect_thread_boundaries.py create mode 100644 scripts/detect_thread_boundaries.py diff --git a/Makefile b/Makefile index c60d9b9b2..fb83cd556 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ # DeerFlow - Unified Development Environment -.PHONY: help config config-upgrade check install setup doctor dev dev-daemon start start-daemon stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway +.PHONY: help config config-upgrade check install setup doctor detect-thread-boundaries dev dev-daemon start start-daemon stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway BASH ?= bash BACKEND_UV_RUN = cd backend && uv run @@ -23,6 +23,7 @@ help: @echo " make config - Generate local config files (aborts if config already exists)" @echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml" @echo " make check - Check if all required tools are installed" + @echo " make detect-thread-boundaries - Inventory async/thread boundary points" @echo " make install - Install all dependencies (frontend + backend + pre-commit hooks)" @echo " make setup-sandbox - Pre-pull sandbox container image (recommended)" @echo " make dev - Start all services in development mode (with hot-reloading)" @@ -51,6 +52,9 @@ setup: doctor: @$(BACKEND_UV_RUN) python ../scripts/doctor.py +detect-thread-boundaries: + @$(PYTHON) ./scripts/detect_thread_boundaries.py + config: @$(PYTHON) ./scripts/configure.py diff --git a/backend/tests/support/detectors/thread_boundaries.py b/backend/tests/support/detectors/thread_boundaries.py new file mode 100644 index 000000000..b1d043d47 --- /dev/null +++ b/backend/tests/support/detectors/thread_boundaries.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 +"""Inventory async/thread boundary points for developer review. + +This detector is intentionally non-invasive: it parses Python source with AST +and reports places where code crosses sync/async/thread boundaries. Findings +are review evidence, not automatic bug decisions. +""" + +from __future__ import annotations + +import argparse +import ast +import json +import os +import sys +from collections.abc import Iterable, Sequence +from dataclasses import asdict, dataclass +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[4] +DEFAULT_SCAN_PATHS = ( + REPO_ROOT / "backend" / "app", + REPO_ROOT / "backend" / "packages" / "harness" / "deerflow", +) +IGNORED_DIR_NAMES = { + ".git", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".venv", + "__pycache__", + "node_modules", +} +SEVERITY_ORDER = {"INFO": 0, "WARN": 1, "FAIL": 2} + + +@dataclass(frozen=True) +class BoundaryFinding: + severity: str + category: str + path: str + line: int + column: int + function: str + async_context: bool + symbol: str + message: str + code: str + + def to_dict(self) -> dict[str, object]: + return asdict(self) + + +@dataclass(frozen=True) +class _FunctionContext: + name: str + is_async: bool + + +@dataclass(frozen=True) +class _CallRule: + severity: str + category: str + message: str + + +EXACT_CALL_RULES: dict[str, _CallRule] = { + "asyncio.run": _CallRule( + "WARN", + "SYNC_ASYNC_BRIDGE", + "Runs a coroutine from synchronous code by creating an event loop boundary.", + ), + "asyncio.to_thread": _CallRule( + "INFO", + "ASYNC_THREAD_OFFLOAD", + "Offloads synchronous work from an async context into a worker thread.", + ), + "asyncio.new_event_loop": _CallRule( + "WARN", + "NEW_EVENT_LOOP", + "Creates a separate event loop; review resource ownership across loops.", + ), + "asyncio.run_coroutine_threadsafe": _CallRule( + "WARN", + "CROSS_THREAD_COROUTINE", + "Submits a coroutine to an event loop from another thread.", + ), + "concurrent.futures.ThreadPoolExecutor": _CallRule( + "INFO", + "THREAD_POOL", + "Creates a thread pool boundary.", + ), + "threading.Thread": _CallRule( + "INFO", + "RAW_THREAD", + "Creates a raw thread; ContextVar values do not propagate automatically.", + ), + "threading.Timer": _CallRule( + "INFO", + "RAW_TIMER_THREAD", + "Creates a timer-backed raw thread; ContextVar values do not propagate automatically.", + ), + "make_sync_tool_wrapper": _CallRule( + "INFO", + "SYNC_TOOL_WRAPPER", + "Adapts an async tool coroutine for synchronous tool invocation.", + ), +} +THREAD_POOL_CONSTRUCTORS = {"concurrent.futures.ThreadPoolExecutor"} +ASYNC_TOOL_FACTORY_CALLS = { + "StructuredTool.from_function", + "langchain.tools.StructuredTool.from_function", + "langchain_core.tools.StructuredTool.from_function", +} +LANGCHAIN_INVOKE_RECEIVER_NAMES = { + "agent", + "chain", + "chat_model", + "graph", + "llm", + "model", + "runnable", +} +LANGCHAIN_INVOKE_RECEIVER_SUFFIXES = ( + "_agent", + "_chain", + "_graph", + "_llm", + "_model", + "_runnable", +) + +ASYNC_BLOCKING_CALL_RULES: dict[str, _CallRule] = { + "time.sleep": _CallRule( + "WARN", + "BLOCKING_CALL_IN_ASYNC", + "Blocks the event loop when called directly inside async code.", + ), + "subprocess.run": _CallRule( + "WARN", + "BLOCKING_SUBPROCESS_IN_ASYNC", + "Runs a blocking subprocess from async code.", + ), + "subprocess.check_call": _CallRule( + "WARN", + "BLOCKING_SUBPROCESS_IN_ASYNC", + "Runs a blocking subprocess from async code.", + ), + "subprocess.check_output": _CallRule( + "WARN", + "BLOCKING_SUBPROCESS_IN_ASYNC", + "Runs a blocking subprocess from async code.", + ), + "subprocess.Popen": _CallRule( + "WARN", + "BLOCKING_SUBPROCESS_IN_ASYNC", + "Starts a subprocess from async code; review whether it blocks later.", + ), +} + + +def dotted_name(node: ast.AST | None) -> str | None: + if isinstance(node, ast.Name): + return node.id + if isinstance(node, ast.Attribute): + parent = dotted_name(node.value) + if parent: + return f"{parent}.{node.attr}" + return node.attr + return None + + +def call_receiver_name(node: ast.Call) -> str | None: + if not isinstance(node.func, ast.Attribute): + return None + return dotted_name(node.func.value) + + +def is_none_node(node: ast.AST | None) -> bool: + return isinstance(node, ast.Constant) and node.value is None + + +class BoundaryVisitor(ast.NodeVisitor): + def __init__(self, path: Path, relative_path: str, source_lines: Sequence[str]) -> None: + self.path = path + self.relative_path = relative_path + self.source_lines = source_lines + self.findings: list[BoundaryFinding] = [] + self.function_stack: list[_FunctionContext] = [] + self.import_aliases: dict[str, str] = {} + self.executor_names: set[str] = set() + + @property + def current_function(self) -> str: + if not self.function_stack: + return "" + return ".".join(context.name for context in self.function_stack) + + @property + def in_async_context(self) -> bool: + return bool(self.function_stack and self.function_stack[-1].is_async) + + def visit_Import(self, node: ast.Import) -> None: + for alias in node.names: + local_name = alias.asname or alias.name.split(".", 1)[0] + canonical_name = alias.name if alias.asname else local_name + self.import_aliases[local_name] = canonical_name + + def visit_ImportFrom(self, node: ast.ImportFrom) -> None: + if node.module is None: + return + for alias in node.names: + local_name = alias.asname or alias.name + self.import_aliases[local_name] = f"{node.module}.{alias.name}" + + def visit_Assign(self, node: ast.Assign) -> None: + self._record_executor_targets(node.value, node.targets) + self.generic_visit(node) + + def visit_AnnAssign(self, node: ast.AnnAssign) -> None: + if node.value is not None: + self._record_executor_targets(node.value, [node.target]) + self.generic_visit(node) + + def visit_With(self, node: ast.With) -> None: + for item in node.items: + if item.optional_vars is not None: + self._record_executor_targets(item.context_expr, [item.optional_vars]) + self.generic_visit(node) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self.function_stack.append(_FunctionContext(node.name, is_async=False)) + self.generic_visit(node) + self.function_stack.pop() + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None: + self.function_stack.append(_FunctionContext(node.name, is_async=True)) + try: + self._check_async_tool_definition(node) + self.generic_visit(node) + finally: + self.function_stack.pop() + + def visit_Call(self, node: ast.Call) -> None: + call_name = self._canonical_name(dotted_name(node.func)) + if call_name: + self._check_call(node, call_name) + self.generic_visit(node) + + def _check_async_tool_definition(self, node: ast.AsyncFunctionDef) -> None: + for decorator in node.decorator_list: + decorator_call = decorator.func if isinstance(decorator, ast.Call) else decorator + decorator_name = self._canonical_name(dotted_name(decorator_call)) + if decorator_name in {"langchain.tools.tool", "langchain_core.tools.tool"}: + self._emit( + node, + severity="INFO", + category="ASYNC_TOOL_DEFINITION", + symbol=decorator_name, + message="Defines an async LangChain tool; sync clients need a wrapper before invoke().", + ) + return + + def _check_call(self, node: ast.Call, call_name: str) -> None: + rule = EXACT_CALL_RULES.get(call_name) + if rule: + self._emit_rule(node, call_name, rule) + + if call_name.endswith(".run_until_complete"): + self._emit( + node, + severity="WARN", + category="RUN_UNTIL_COMPLETE", + symbol=call_name, + message="Drives an event loop from synchronous code; review nested-loop behavior.", + ) + + if self._is_executor_submit(node, call_name): + self._emit( + node, + severity="INFO", + category="EXECUTOR_SUBMIT", + symbol=call_name, + message="Submits work to an executor; review context propagation and cancellation.", + ) + + if call_name in ASYNC_TOOL_FACTORY_CALLS: + if any(keyword.arg == "coroutine" and not is_none_node(keyword.value) for keyword in node.keywords): + self._emit( + node, + severity="INFO", + category="ASYNC_ONLY_TOOL_FACTORY", + symbol=call_name, + message="Creates a StructuredTool from a coroutine; sync clients need a wrapper.", + ) + + if self.in_async_context and call_name in ASYNC_BLOCKING_CALL_RULES: + self._emit_rule(node, call_name, ASYNC_BLOCKING_CALL_RULES[call_name]) + + if self.in_async_context and self._is_langchain_invoke(node, call_name, method_name="invoke"): + self._emit( + node, + severity="WARN", + category="SYNC_INVOKE_IN_ASYNC", + symbol=call_name, + message="Calls a synchronous invoke() from async code; review event-loop blocking.", + ) + + if not self.in_async_context and self._is_langchain_invoke(node, call_name, method_name="ainvoke"): + self._emit( + node, + severity="WARN", + category="ASYNC_INVOKE_IN_SYNC", + symbol=call_name, + message="Calls async ainvoke() from sync code; review how the coroutine is awaited.", + ) + + def _canonical_name(self, name: str | None) -> str | None: + if name is None: + return None + parts = name.split(".") + if parts and parts[0] in self.import_aliases: + return ".".join((self.import_aliases[parts[0]], *parts[1:])) + return name + + def _record_executor_targets(self, value: ast.AST, targets: Sequence[ast.AST]) -> None: + if not isinstance(value, ast.Call): + return + call_name = self._canonical_name(dotted_name(value.func)) + if call_name not in THREAD_POOL_CONSTRUCTORS: + return + for target in targets: + for name in self._target_names(target): + self.executor_names.add(name) + + def _target_names(self, target: ast.AST) -> Iterable[str]: + if isinstance(target, ast.Name): + yield target.id + elif isinstance(target, (ast.Tuple, ast.List)): + for element in target.elts: + yield from self._target_names(element) + + def _is_executor_submit(self, node: ast.Call, call_name: str) -> bool: + if not call_name.endswith(".submit"): + return False + receiver_name = call_receiver_name(node) + return receiver_name in self.executor_names + + def _is_langchain_invoke(self, node: ast.Call, call_name: str, *, method_name: str) -> bool: + if not call_name.endswith(f".{method_name}"): + return False + receiver_name = call_receiver_name(node) + if receiver_name is None: + return False + receiver_leaf = receiver_name.rsplit(".", 1)[-1] + return receiver_leaf in LANGCHAIN_INVOKE_RECEIVER_NAMES or receiver_leaf.endswith(LANGCHAIN_INVOKE_RECEIVER_SUFFIXES) + + def _emit_rule(self, node: ast.AST, symbol: str, rule: _CallRule) -> None: + self._emit( + node, + severity=rule.severity, + category=rule.category, + symbol=symbol, + message=rule.message, + ) + + def _emit(self, node: ast.AST, *, severity: str, category: str, symbol: str, message: str) -> None: + line = getattr(node, "lineno", 0) + column = getattr(node, "col_offset", 0) + code = "" + if line > 0 and line <= len(self.source_lines): + code = self.source_lines[line - 1].strip() + self.findings.append( + BoundaryFinding( + severity=severity, + category=category, + path=self.relative_path, + line=line, + column=column, + function=self.current_function, + async_context=self.in_async_context, + symbol=symbol, + message=message, + code=code, + ) + ) + + +def relative_to_repo(path: Path, repo_root: Path = REPO_ROOT) -> str: + try: + return path.resolve().relative_to(repo_root.resolve()).as_posix() + except ValueError: + return path.as_posix() + + +def scan_file(path: Path, *, repo_root: Path = REPO_ROOT) -> list[BoundaryFinding]: + source = path.read_text(encoding="utf-8") + source_lines = source.splitlines() + relative_path = relative_to_repo(path, repo_root) + try: + tree = ast.parse(source, filename=str(path)) + except SyntaxError as exc: + line = exc.lineno or 0 + code = source_lines[line - 1].strip() if line > 0 and line <= len(source_lines) else "" + return [ + BoundaryFinding( + severity="WARN", + category="PARSE_ERROR", + path=relative_path, + line=line, + column=max((exc.offset or 1) - 1, 0), + function="", + async_context=False, + symbol="SyntaxError", + message=str(exc), + code=code, + ) + ] + + visitor = BoundaryVisitor(path, relative_path, source_lines) + visitor.visit(tree) + return visitor.findings + + +def is_ignored_path(path: Path) -> bool: + return any(part in IGNORED_DIR_NAMES for part in path.parts) + + +def iter_python_files(paths: Iterable[Path]) -> Iterable[Path]: + for path in paths: + if not path.exists() or is_ignored_path(path): + continue + if path.is_file(): + if path.suffix == ".py" and not is_ignored_path(path): + yield path + continue + for dirpath, dirnames, filenames in os.walk(path): + dirnames[:] = [dirname for dirname in dirnames if dirname not in IGNORED_DIR_NAMES] + for filename in filenames: + if filename.endswith(".py"): + yield Path(dirpath) / filename + + +def scan_paths(paths: Iterable[Path], *, repo_root: Path = REPO_ROOT) -> list[BoundaryFinding]: + findings: list[BoundaryFinding] = [] + for path in sorted(iter_python_files(paths)): + findings.extend(scan_file(path, repo_root=repo_root)) + return sorted(findings, key=lambda finding: (finding.path, finding.line, finding.column, finding.category)) + + +def filter_findings(findings: Iterable[BoundaryFinding], min_severity: str) -> list[BoundaryFinding]: + threshold = SEVERITY_ORDER[min_severity] + return [finding for finding in findings if SEVERITY_ORDER[finding.severity] >= threshold] + + +def format_text(findings: Sequence[BoundaryFinding]) -> str: + if not findings: + return "No async/thread boundary findings." + + lines: list[str] = [] + for finding in findings: + lines.append(f"{finding.severity} {finding.category} {finding.path}:{finding.line}:{finding.column + 1} in {finding.function} async={str(finding.async_context).lower()}") + lines.append(f" symbol: {finding.symbol}") + lines.append(f" note: {finding.message}") + if finding.code: + lines.append(f" code: {finding.code}") + return "\n".join(lines) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description=("Detect async/thread boundary points for developer review. Findings are an inventory, not automatic bug decisions.")) + parser.add_argument( + "paths", + nargs="*", + type=Path, + help="Files or directories to scan. Defaults to backend app and harness sources.", + ) + parser.add_argument( + "--format", + choices=("text", "json"), + default="text", + help="Output format.", + ) + parser.add_argument( + "--min-severity", + choices=tuple(SEVERITY_ORDER), + default="INFO", + help="Only show findings at or above this severity.", + ) + return parser + + +def main(argv: Sequence[str] | None = None) -> int: + parser = build_parser() + args = parser.parse_args(argv) + paths = args.paths or list(DEFAULT_SCAN_PATHS) + findings = filter_findings(scan_paths(paths), args.min_severity) + + if args.format == "json": + print(json.dumps([finding.to_dict() for finding in findings], indent=2, sort_keys=True)) + else: + print(format_text(findings)) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/backend/tests/test_detect_thread_boundaries.py b/backend/tests/test_detect_thread_boundaries.py new file mode 100644 index 000000000..102613e39 --- /dev/null +++ b/backend/tests/test_detect_thread_boundaries.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +import json +import textwrap +from pathlib import Path + +from support.detectors import thread_boundaries as detector + + +def _write_python(path: Path, source: str) -> Path: + path.write_text(textwrap.dedent(source).strip() + "\n", encoding="utf-8") + return path + + +def test_scan_file_detects_async_thread_and_tool_boundaries(tmp_path): + source_file = _write_python( + tmp_path / "sample.py", + """ + import asyncio + import threading + import time + from concurrent.futures import ThreadPoolExecutor + from langchain.tools import tool + from langchain_core.tools import StructuredTool + + @tool + async def async_tool(value: int) -> str: + return str(value) + + async def handler(model): + await asyncio.to_thread(str, "x") + model.invoke("blocking") + time.sleep(1) + + def sync_entry(): + asyncio.run(handler(None)) + pool = ThreadPoolExecutor(max_workers=1) + pool.submit(str, "x") + threading.Thread(target=sync_entry).start() + return StructuredTool.from_function( + name="factory_tool", + description="factory", + coroutine=async_tool, + ) + """, + ) + + findings = detector.scan_file(source_file, repo_root=tmp_path) + categories = {finding.category for finding in findings} + async_tool_finding = next(finding for finding in findings if finding.category == "ASYNC_TOOL_DEFINITION") + + assert "ASYNC_TOOL_DEFINITION" in categories + assert async_tool_finding.function == "async_tool" + assert async_tool_finding.async_context is True + assert "ASYNC_THREAD_OFFLOAD" in categories + assert "SYNC_INVOKE_IN_ASYNC" in categories + assert "BLOCKING_CALL_IN_ASYNC" in categories + assert "SYNC_ASYNC_BRIDGE" in categories + assert "THREAD_POOL" in categories + assert "EXECUTOR_SUBMIT" in categories + assert "RAW_THREAD" in categories + assert "ASYNC_ONLY_TOOL_FACTORY" in categories + + +def test_scan_file_ignores_unqualified_threads_and_generic_method_names(tmp_path): + source_file = _write_python( + tmp_path / "sample.py", + """ + class Thread: + pass + + class Timer: + pass + + async def handler(form, runner): + form.submit() + runner.invoke("not a langchain model") + + def sync_entry(runner): + Thread() + Timer() + runner.ainvoke("not a langchain model") + """, + ) + + findings = detector.scan_file(source_file, repo_root=tmp_path) + categories = {finding.category for finding in findings} + + assert "RAW_THREAD" not in categories + assert "RAW_TIMER_THREAD" not in categories + assert "EXECUTOR_SUBMIT" not in categories + assert "SYNC_INVOKE_IN_ASYNC" not in categories + assert "ASYNC_INVOKE_IN_SYNC" not in categories + + +def test_scan_file_uses_import_evidence_for_thread_and_executor_aliases(tmp_path): + source_file = _write_python( + tmp_path / "sample.py", + """ + from concurrent.futures import ThreadPoolExecutor as Pool + from threading import Thread as WorkerThread, Timer + + def sync_entry(): + pool = Pool(max_workers=1) + pool.submit(str, "x") + WorkerThread(target=sync_entry).start() + Timer(1, sync_entry).start() + """, + ) + + findings = detector.scan_file(source_file, repo_root=tmp_path) + categories = {finding.category for finding in findings} + + assert "THREAD_POOL" in categories + assert "EXECUTOR_SUBMIT" in categories + assert "RAW_THREAD" in categories + assert "RAW_TIMER_THREAD" in categories + + +def test_scan_paths_ignores_virtualenv_like_directories(tmp_path): + scanned_file = _write_python( + tmp_path / "app.py", + """ + import asyncio + + def main(): + return asyncio.run(asyncio.sleep(0)) + """, + ) + ignored_dir = tmp_path / ".venv" + ignored_dir.mkdir() + _write_python( + ignored_dir / "ignored.py", + """ + import threading + + thread = threading.Thread(target=lambda: None) + """, + ) + + findings = detector.scan_paths([tmp_path], repo_root=tmp_path) + + assert any(finding.path == scanned_file.name for finding in findings) + assert all(".venv" not in finding.path for finding in findings) + + +def test_json_output_and_min_severity_filter(tmp_path, capsys): + source_file = _write_python( + tmp_path / "sample.py", + """ + import asyncio + + async def handler(model): + await asyncio.to_thread(str, "x") + model.invoke("blocking") + """, + ) + + exit_code = detector.main(["--format", "json", "--min-severity", "WARN", str(source_file)]) + + assert exit_code == 0 + payload = json.loads(capsys.readouterr().out) + categories = {finding["category"] for finding in payload} + assert categories == {"SYNC_INVOKE_IN_ASYNC"} + + +def test_parse_errors_are_reported_as_findings(tmp_path): + source_file = _write_python( + tmp_path / "broken.py", + """ + def broken(: + pass + """, + ) + + findings = detector.scan_file(source_file, repo_root=tmp_path) + + assert len(findings) == 1 + assert findings[0].category == "PARSE_ERROR" + assert findings[0].severity == "WARN" + assert findings[0].column == 11 + assert f"{source_file.name}:1:12" in detector.format_text(findings) diff --git a/scripts/detect_thread_boundaries.py b/scripts/detect_thread_boundaries.py new file mode 100644 index 000000000..c4a59132e --- /dev/null +++ b/scripts/detect_thread_boundaries.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +"""CLI wrapper for the async/thread boundary detector.""" + +from __future__ import annotations + +import sys +from collections.abc import Sequence +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +TEST_SUPPORT_PATH = REPO_ROOT / "backend" / "tests" +if str(TEST_SUPPORT_PATH) not in sys.path: + sys.path.insert(0, str(TEST_SUPPORT_PATH)) + + +def main(argv: Sequence[str] | None = None) -> int: + from support.detectors.thread_boundaries import main as detector_main + + return detector_main(argv) + + +if __name__ == "__main__": + sys.exit(main()) From e37912e2c85925fcdc712fa0b50fafe08602f884 Mon Sep 17 00:00:00 2001 From: Xun Date: Wed, 20 May 2026 10:16:31 +0800 Subject: [PATCH 46/86] feat(sandbox) Adds download file interface in Sandbox (#3038) * Add download interface in Sandbox * fix * fix * del invalidate test * fix * safe download * improve --- .../community/aio_sandbox/aio_sandbox.py | 47 ++++++++++ .../deerflow/sandbox/local/local_sandbox.py | 26 ++++++ .../harness/deerflow/sandbox/sandbox.py | 19 ++++ backend/tests/test_aio_sandbox.py | 85 ++++++++++++++++++ .../test_local_sandbox_provider_mounts.py | 88 +++++++++++++++++++ 5 files changed, 265 insertions(+) diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py index 97da4144d..cdc8e1b77 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py @@ -1,4 +1,5 @@ import base64 +import errno import logging import shlex import threading @@ -6,11 +7,14 @@ import uuid from agent_sandbox import Sandbox as AioSandboxClient +from deerflow.config.paths import VIRTUAL_PATH_PREFIX from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line logger = logging.getLogger(__name__) +_MAX_DOWNLOAD_SIZE = 100 * 1024 * 1024 # 100 MB + _ERROR_OBSERVATION_SIGNATURE = "'ErrorObservation' object has no attribute 'exit_code'" @@ -102,6 +106,49 @@ class AioSandbox(Sandbox): logger.error(f"Failed to read file in sandbox: {e}") return f"Error: {e}" + def download_file(self, path: str) -> bytes: + """Download file bytes from the sandbox. + + Raises: + PermissionError: If the path contains '..' traversal segments or is + outside ``VIRTUAL_PATH_PREFIX``. + OSError: If the file cannot be retrieved from the sandbox. + """ + # Reject path traversal before sending to the container API. + # LocalSandbox gets this implicitly via _resolve_path; + # here the path is forwarded verbatim so we must check explicitly. + normalised = path.replace("\\", "/") + for segment in normalised.split("/"): + if segment == "..": + logger.error(f"Refused download due to path traversal: {path}") + raise PermissionError(f"Access denied: path traversal detected in '{path}'") + + stripped_path = normalised.lstrip("/") + allowed_prefix = VIRTUAL_PATH_PREFIX.lstrip("/") + if stripped_path != allowed_prefix and not stripped_path.startswith(f"{allowed_prefix}/"): + logger.error("Refused download outside allowed directory: path=%s, allowed_prefix=%s", path, VIRTUAL_PATH_PREFIX) + raise PermissionError(f"Access denied: path must be under '{VIRTUAL_PATH_PREFIX}': '{path}'") + + with self._lock: + try: + chunks: list[bytes] = [] + total = 0 + for chunk in self._client.file.download_file(path=path): + total += len(chunk) + if total > _MAX_DOWNLOAD_SIZE: + raise OSError( + errno.EFBIG, + f"File exceeds maximum download size of {_MAX_DOWNLOAD_SIZE} bytes", + path, + ) + chunks.append(chunk) + return b"".join(chunks) + except OSError: + raise + except Exception as e: + logger.error(f"Failed to download file in sandbox: {e}") + raise OSError(f"Failed to download file '{path}' from sandbox: {e}") from e + def list_dir(self, path: str, max_depth: int = 2) -> list[str]: """List the contents of a directory in the sandbox. diff --git a/backend/packages/harness/deerflow/sandbox/local/local_sandbox.py b/backend/packages/harness/deerflow/sandbox/local/local_sandbox.py index 62577abb9..0d7682733 100644 --- a/backend/packages/harness/deerflow/sandbox/local/local_sandbox.py +++ b/backend/packages/harness/deerflow/sandbox/local/local_sandbox.py @@ -1,4 +1,5 @@ import errno +import logging import ntpath import os import shutil @@ -7,10 +8,13 @@ from dataclasses import dataclass from pathlib import Path from typing import NamedTuple +from deerflow.config.paths import VIRTUAL_PATH_PREFIX from deerflow.sandbox.local.list_dir import list_dir from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches +logger = logging.getLogger(__name__) + @dataclass(frozen=True) class PathMapping: @@ -379,6 +383,28 @@ class LocalSandbox(Sandbox): # Re-raise with the original path for clearer error messages, hiding internal resolved paths raise type(e)(e.errno, e.strerror, path) from None + def download_file(self, path: str) -> bytes: + normalised = path.replace("\\", "/") + stripped_path = normalised.lstrip("/") + allowed_prefix = VIRTUAL_PATH_PREFIX.lstrip("/") + if stripped_path != allowed_prefix and not stripped_path.startswith(f"{allowed_prefix}/"): + logger.error("Refused download outside allowed directory: path=%s, allowed_prefix=%s", path, VIRTUAL_PATH_PREFIX) + raise PermissionError(errno.EACCES, f"Access denied: path must be under '{VIRTUAL_PATH_PREFIX}'", path) + + resolved_path = self._resolve_path(path) + max_download_size = 100 * 1024 * 1024 + try: + file_size = os.path.getsize(resolved_path) + if file_size > max_download_size: + raise OSError(errno.EFBIG, f"File exceeds maximum download size of {max_download_size} bytes", path) + # TOCTOU note: the file could grow between getsize() and read(); accepted + # tradeoff since this is a controlled sandbox environment. + with open(resolved_path, "rb") as f: + return f.read() + except OSError as e: + # Re-raise with the original path for clearer error messages, hiding internal resolved paths + raise type(e)(e.errno, e.strerror, path) from None + def write_file(self, path: str, content: str, append: bool = False) -> None: resolved = self._resolve_path_with_mapping(path) resolved_path = resolved.path diff --git a/backend/packages/harness/deerflow/sandbox/sandbox.py b/backend/packages/harness/deerflow/sandbox/sandbox.py index dc567b503..50322f419 100644 --- a/backend/packages/harness/deerflow/sandbox/sandbox.py +++ b/backend/packages/harness/deerflow/sandbox/sandbox.py @@ -39,6 +39,25 @@ class Sandbox(ABC): """ pass + @abstractmethod + def download_file(self, path: str) -> bytes: + """Download the binary content of a file. + + Args: + path: The absolute path of the file to download. + + Returns: + Raw file bytes. + + Raises: + PermissionError: If path traversal is detected or the path is outside + the allowed virtual prefix. + OSError: If the file cannot be read or does not exist. Both local + and remote implementations must raise ``OSError`` so callers + have a single exception type to handle. + """ + pass + @abstractmethod def list_dir(self, path: str, max_depth=2) -> list[str]: """List the contents of a directory. diff --git a/backend/tests/test_aio_sandbox.py b/backend/tests/test_aio_sandbox.py index c6acb46eb..3b0a44f05 100644 --- a/backend/tests/test_aio_sandbox.py +++ b/backend/tests/test_aio_sandbox.py @@ -233,3 +233,88 @@ class TestConcurrentFileWrites: thread.join() assert storage["content"] in {"seed\nA\nB\n", "seed\nB\nA\n"} + + +class TestDownloadFile: + """Tests for AioSandbox.download_file.""" + + def test_returns_concatenated_bytes(self, sandbox): + """download_file should join chunks from the client iterator into bytes.""" + sandbox._client.file.download_file = MagicMock(return_value=[b"hel", b"lo"]) + + result = sandbox.download_file("/mnt/user-data/outputs/file.bin") + + assert result == b"hello" + sandbox._client.file.download_file.assert_called_once_with(path="/mnt/user-data/outputs/file.bin") + + def test_returns_empty_bytes_for_empty_file(self, sandbox): + """download_file should return b'' when the iterator yields nothing.""" + sandbox._client.file.download_file = MagicMock(return_value=iter([])) + + result = sandbox.download_file("/mnt/user-data/outputs/empty.bin") + + assert result == b"" + + def test_uses_lock_during_download(self, sandbox): + """download_file should hold the lock while calling the client.""" + lock_was_held = [] + + def tracking_download(path): + lock_was_held.append(sandbox._lock.locked()) + return iter([b"data"]) + + sandbox._client.file.download_file = tracking_download + + sandbox.download_file("/mnt/user-data/outputs/file.bin") + + assert lock_was_held == [True], "download_file must hold the lock during client call" + + def test_raises_oserror_on_client_error(self, sandbox): + """download_file should wrap client exceptions as OSError.""" + sandbox._client.file.download_file = MagicMock(side_effect=RuntimeError("network error")) + + with pytest.raises(OSError, match="network error"): + sandbox.download_file("/mnt/user-data/outputs/file.bin") + + def test_preserves_oserror_from_client(self, sandbox): + """OSError raised by the client should propagate without re-wrapping.""" + sandbox._client.file.download_file = MagicMock(side_effect=OSError("disk error")) + + with pytest.raises(OSError, match="disk error"): + sandbox.download_file("/mnt/user-data/outputs/file.bin") + + def test_rejects_path_outside_virtual_prefix_and_logs_error(self, sandbox, caplog): + """download_file must reject downloads outside /mnt/user-data and log the reason.""" + sandbox._client.file.download_file = MagicMock() + + with caplog.at_level("ERROR"): + with pytest.raises(PermissionError, match="must be under"): + sandbox.download_file("/etc/passwd") + + assert "outside allowed directory" in caplog.text + sandbox._client.file.download_file.assert_not_called() + + @pytest.mark.parametrize( + "path", + [ + "/mnt/workspace/../../etc/passwd", + "../secret", + "/a/b/../../../etc/shadow", + ], + ) + def test_rejects_path_traversal(self, sandbox, path): + """download_file must reject paths containing '..' before calling the client.""" + sandbox._client.file.download_file = MagicMock() + + with pytest.raises(PermissionError, match="path traversal"): + sandbox.download_file(path) + + sandbox._client.file.download_file.assert_not_called() + + def test_single_chunk(self, sandbox): + """download_file should work correctly with a single-chunk response.""" + sandbox._client.file.download_file = MagicMock(return_value=[b"single-chunk"]) + + result = sandbox.download_file("/mnt/user-data/outputs/single.bin") + + assert result == b"single-chunk" diff --git a/backend/tests/test_local_sandbox_provider_mounts.py b/backend/tests/test_local_sandbox_provider_mounts.py index 5e7a06b6d..add5c4ea6 100644 --- a/backend/tests/test_local_sandbox_provider_mounts.py +++ b/backend/tests/test_local_sandbox_provider_mounts.py @@ -204,6 +204,26 @@ class TestSymlinkEscapes: assert exc_info.value.errno == errno.EACCES + def test_download_file_blocks_symlink_escape_from_mount(self, tmp_path): + mount_dir = tmp_path / "mount" + mount_dir.mkdir() + outside_dir = tmp_path / "outside" + outside_dir.mkdir() + (outside_dir / "secret.bin").write_bytes(b"\x00secret") + _symlink_to(outside_dir, mount_dir / "escape", target_is_directory=True) + + sandbox = LocalSandbox( + "test", + [ + PathMapping(container_path="/mnt/user-data", local_path=str(mount_dir), read_only=False), + ], + ) + + with pytest.raises(PermissionError) as exc_info: + sandbox.download_file("/mnt/user-data/escape/secret.bin") + + assert exc_info.value.errno == errno.EACCES + def test_write_file_blocks_symlink_escape_from_mount(self, tmp_path): mount_dir = tmp_path / "mount" mount_dir.mkdir() @@ -334,6 +354,74 @@ class TestSymlinkEscapes: assert existing.read_bytes() == b"original" +class TestDownloadFileMappings: + """download_file must use _resolve_path_with_mapping so path resolution, symlink + containment, and read-only awareness are consistent with read_file.""" + + def test_resolves_container_path_via_mapping(self, tmp_path): + """download_file should resolve container paths through path mappings.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + (data_dir / "asset.bin").write_bytes(b"\x01\x02\x03") + + sandbox = LocalSandbox( + "test", + [PathMapping(container_path="/mnt/user-data", local_path=str(data_dir))], + ) + + result = sandbox.download_file("/mnt/user-data/asset.bin") + + assert result == b"\x01\x02\x03" + + def test_raises_oserror_with_original_path_when_missing(self, tmp_path): + """OSError filename should show the container path, not the resolved host path.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + + sandbox = LocalSandbox( + "test", + [PathMapping(container_path="/mnt/user-data", local_path=str(data_dir))], + ) + + with pytest.raises(OSError) as exc_info: + sandbox.download_file("/mnt/user-data/missing.bin") + + assert exc_info.value.filename == "/mnt/user-data/missing.bin" + + def test_rejects_path_outside_virtual_prefix_and_logs_error(self, tmp_path, caplog): + """download_file must reject paths outside /mnt/user-data and log the reason.""" + data_dir = tmp_path / "data" + data_dir.mkdir() + (data_dir / "model.bin").write_bytes(b"weights") + + sandbox = LocalSandbox( + "test", + [PathMapping(container_path="/mnt/user-data", local_path=str(data_dir), read_only=True)], + ) + + with caplog.at_level("ERROR"): + with pytest.raises(PermissionError) as exc_info: + sandbox.download_file("/mnt/skills/model.bin") + + assert exc_info.value.errno == errno.EACCES + assert "outside allowed directory" in caplog.text + + def test_readable_from_read_only_mount(self, tmp_path): + """Read-only mounts must not block download_file — read-only only restricts writes.""" + skills_dir = tmp_path / "skills" + skills_dir.mkdir() + (skills_dir / "model.bin").write_bytes(b"weights") + + sandbox = LocalSandbox( + "test", + [PathMapping(container_path="/mnt/user-data", local_path=str(skills_dir), read_only=True)], + ) + + result = sandbox.download_file("/mnt/user-data/model.bin") + + assert result == b"weights" + + class TestMultipleMounts: def test_multiple_read_write_mounts(self, tmp_path): skills_dir = tmp_path / "skills" From 8cd4710b169f6e01f9392a3e1e62fcdeb58fea3b Mon Sep 17 00:00:00 2001 From: john lee <64lamei@gmail.com> Date: Wed, 20 May 2026 10:43:18 +0800 Subject: [PATCH 47/86] fix(deploy): fall back to python/openssl when python3 is absent for secret generation (#3074) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(deploy): fall back to python/openssl when python3 is absent for secret generation Bare python3 call in deploy.sh exits 49 on systems without python3 in PATH (e.g. some Alpine/minimal containers, or Windows environments where only 'python' is on PATH). Add a fallback chain: python3 → python → openssl rand -hex 32. If all three are unavailable, emit a clear error message and exit with a non-zero status instead of a cryptic recipe failure. Closes #2922 Co-Authored-By: Claude Sonnet 4.6 * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Willem Jiang Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- scripts/deploy.sh | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/scripts/deploy.sh b/scripts/deploy.sh index b4b030d4b..41c9dfa3f 100755 --- a/scripts/deploy.sh +++ b/scripts/deploy.sh @@ -120,7 +120,20 @@ if [ -z "$BETTER_AUTH_SECRET" ]; then echo -e "${GREEN}✓ BETTER_AUTH_SECRET loaded from $_secret_file${NC}" else export BETTER_AUTH_SECRET - BETTER_AUTH_SECRET="$(python3 -c 'import secrets; print(secrets.token_hex(32))')" + if command -v python3 > /dev/null 2>&1 && \ + BETTER_AUTH_SECRET="$(python3 -c 'import sys; sys.version_info >= (3, 6) or sys.exit(1); import secrets; print(secrets.token_hex(32))' 2>/dev/null)"; then + true + elif command -v python > /dev/null 2>&1 && \ + BETTER_AUTH_SECRET="$(python -c 'import sys; sys.version_info >= (3, 6) or sys.exit(1); import secrets; print(secrets.token_hex(32))' 2>/dev/null)"; then + true + elif command -v openssl > /dev/null 2>&1 && \ + BETTER_AUTH_SECRET="$(openssl rand -hex 32)"; then + true + else + echo -e "${RED}✗ Cannot generate BETTER_AUTH_SECRET: python3, python, and openssl are all unavailable.${NC}" >&2 + echo -e "${RED} Set BETTER_AUTH_SECRET manually before running make up.${NC}" >&2 + exit 1 + fi echo "$BETTER_AUTH_SECRET" > "$_secret_file" chmod 600 "$_secret_file" echo -e "${GREEN}✓ BETTER_AUTH_SECRET generated → $_secret_file${NC}" From 6b922e490837c8fd5fbedab58def88069ed933d3 Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Wed, 20 May 2026 14:52:58 +0800 Subject: [PATCH 48/86] test(runtime): add lifecycle e2e coverage (#2946) * test(runtime): add lifecycle e2e coverage * test: isolate runtime lifecycle e2e config * test(runtime): document lifecycle e2e tradeoffs --- backend/tests/test_runtime_lifecycle_e2e.py | 686 ++++++++++++++++++++ 1 file changed, 686 insertions(+) create mode 100644 backend/tests/test_runtime_lifecycle_e2e.py diff --git a/backend/tests/test_runtime_lifecycle_e2e.py b/backend/tests/test_runtime_lifecycle_e2e.py new file mode 100644 index 000000000..1eda351ec --- /dev/null +++ b/backend/tests/test_runtime_lifecycle_e2e.py @@ -0,0 +1,686 @@ +"""HTTP/runtime lifecycle E2E tests for the Gateway-owned runs API. + +These tests keep the external model out of scope while exercising the real +FastAPI app, auth middleware, lifespan-created runtime dependencies, +``start_run()``, ``run_agent()``, StreamBridge, checkpointer, run store, and +thread metadata store. +""" + +from __future__ import annotations + +import asyncio +import inspect +import json +import queue +import threading +import time +import uuid +from contextlib import suppress +from pathlib import Path +from typing import Any +from unittest.mock import patch + +import pytest +from _agent_e2e_helpers import FakeToolCallingModel, build_single_tool_call_model +from langchain_core.messages import AIMessage, HumanMessage + +pytestmark = pytest.mark.no_auto_user + + +_MINIMAL_CONFIG_YAML = """\ +log_level: info +models: + - name: fake-test-model + display_name: Fake Test Model + use: langchain_openai:ChatOpenAI + model: gpt-4o-mini + api_key: $OPENAI_API_KEY + base_url: $OPENAI_API_BASE +sandbox: + use: deerflow.sandbox.local:LocalSandboxProvider +agents_api: + enabled: true +title: + enabled: false +memory: + enabled: false +database: + backend: sqlite +run_events: + backend: memory +""" + + +class _RunController: + """Cross-thread controls for the fake async agent.""" + + def __init__(self) -> None: + self.started = threading.Event() + self.checkpoint_written = threading.Event() + self.cancelled = threading.Event() + self.release = threading.Event() + self.instances: list[_ScriptedAgent] = [] + + +class _ScriptedAgent: + """Deterministic runtime double for lifecycle-only tests. + + This is intentionally not a full LangGraph graph. Tests that need + controllable blocking, cancellation, and rollback checkpoints use the small + ``run_agent`` surface they exercise: ``astream()``, checkpointer/store + attachment, metadata, and interrupt node attributes. The real lead-agent + graph/tool dispatch path is covered separately by + ``test_stream_run_executes_real_lead_agent_setup_agent_business_path``. + """ + + def __init__( + self, + controller: _RunController, + *, + title: str, + answer: str, + block_after_first_chunk: bool = False, + ) -> None: + self.controller = controller + self.title = title + self.answer = answer + self.block_after_first_chunk = block_after_first_chunk + self.checkpointer: Any | None = None + self.store: Any | None = None + self.metadata = {"model_name": "fake-test-model"} + self.interrupt_before_nodes = None + self.interrupt_after_nodes = None + self.model = FakeToolCallingModel(responses=[AIMessage(content=self.answer)]) + + async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False): + del subgraphs + self.controller.started.set() + + thread_id = _thread_id_from_config(config) + human_text = _last_human_text(graph_input) + human = HumanMessage(content=human_text) + ai = await self.model.ainvoke([human], config=config) + state = {"messages": [human.model_dump(), ai.model_dump()], "title": self.title} + + if self.checkpointer is not None: + await _write_checkpoint(self.checkpointer, thread_id=thread_id, state=state) + self.controller.checkpoint_written.set() + + yield _stream_item_for_mode(stream_mode, state) + + if self.block_after_first_chunk: + try: + while not self.controller.release.is_set(): + await asyncio.sleep(0.05) + except asyncio.CancelledError: + self.controller.cancelled.set() + raise + + +def _make_agent_factory(controller: _RunController, **agent_kwargs): + def factory(*, config): + del config + agent = _ScriptedAgent(controller, **agent_kwargs) + controller.instances.append(agent) + return agent + + return factory + + +def _build_fake_setup_agent_model(agent_name: str): + """Patch target for lead_agent.agent.create_chat_model. + + The graph, tool registry, ToolNode dispatch, and setup_agent implementation + remain production code; this fake only replaces the external LLM call. + """ + + def fake_create_chat_model(*args: Any, **kwargs: Any) -> FakeToolCallingModel: + del args, kwargs + return build_single_tool_call_model( + tool_name="setup_agent", + tool_args={ + "soul": f"# Runtime Business E2E\n\nAgent name: {agent_name}", + "description": "runtime lifecycle business path", + }, + tool_call_id="call_runtime_business_1", + final_text=f"Created {agent_name} through the real setup_agent tool.", + ) + + return fake_create_chat_model + + +@pytest.fixture +def isolated_deer_flow_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Path: + home = tmp_path / "deer-flow-home" + home.mkdir() + monkeypatch.setenv("DEER_FLOW_HOME", str(home)) + monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-key-not-used") + monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid") + + staged_config = tmp_path / "config.yaml" + staged_config.write_text(_MINIMAL_CONFIG_YAML, encoding="utf-8") + monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(staged_config)) + + staged_extensions_config = tmp_path / "extensions_config.json" + staged_extensions_config.write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8") + monkeypatch.setenv("DEER_FLOW_EXTENSIONS_CONFIG_PATH", str(staged_extensions_config)) + return home + + +def _reset_process_singletons(monkeypatch: pytest.MonkeyPatch) -> None: + """Clear runtime singletons that depend on this test's temporary config. + + The Gateway app/lifespan path reads process-wide caches before wiring + request-scoped dependencies. These E2E tests stage a temporary + ``config.yaml``/``extensions_config.json`` and ``DEER_FLOW_HOME``, so the + caches below must be reset before app creation: + + - app_config / extensions_config: parsed config file caches. + - paths: ``DEER_FLOW_HOME``-derived filesystem paths. + - persistence.engine: SQLAlchemy engine/session factory for the sqlite dir. + - app.gateway.deps: cached local auth provider/repository. + + A shared public reset helper would be cleaner long-term; this test keeps + the reset boundary explicit because the PR is focused on runtime lifecycle + coverage rather than config-cache API cleanup. + """ + + from app.gateway import deps as deps_module + from deerflow.config import app_config as app_config_module + from deerflow.config import extensions_config as extensions_config_module + from deerflow.config import paths as paths_module + from deerflow.persistence import engine as engine_module + + for module, attr, value in ( + (app_config_module, "_app_config", None), + (app_config_module, "_app_config_path", None), + (app_config_module, "_app_config_mtime", None), + (app_config_module, "_app_config_is_custom", False), + (extensions_config_module, "_extensions_config", None), + (paths_module, "_paths_singleton", None), + (paths_module, "_paths", None), + (engine_module, "_engine", None), + (engine_module, "_session_factory", None), + (deps_module, "_cached_local_provider", None), + (deps_module, "_cached_repo", None), + ): + monkeypatch.setattr(module, attr, value, raising=False) + + +def _preserve_process_config_singletons(monkeypatch: pytest.MonkeyPatch) -> None: + """Restore config singletons mutated as a side effect of AppConfig loading. + + ``AppConfig.from_file()`` calls ``_apply_singleton_configs()``, which pushes + nested config sections into module-level caches used by middlewares, tool + selection, and runtime providers. Snapshotting those attributes with + ``monkeypatch`` lets pytest restore the pre-test values during teardown, so + loading the isolated test config does not leak into later tests. + """ + + from deerflow.config import ( + acp_config, + agents_api_config, + checkpointer_config, + guardrails_config, + memory_config, + stream_bridge_config, + subagents_config, + summarization_config, + title_config, + tool_search_config, + ) + + for module, attr in ( + (title_config, "_title_config"), + (summarization_config, "_summarization_config"), + (memory_config, "_memory_config"), + (agents_api_config, "_agents_api_config"), + (subagents_config, "_subagents_config"), + (tool_search_config, "_tool_search_config"), + (guardrails_config, "_guardrails_config"), + (checkpointer_config, "_checkpointer_config"), + (stream_bridge_config, "_stream_bridge_config"), + (acp_config, "_acp_agents"), + ): + monkeypatch.setattr(module, attr, getattr(module, attr), raising=False) + + +@pytest.fixture +def isolated_app(isolated_deer_flow_home: Path, monkeypatch: pytest.MonkeyPatch): + _preserve_process_config_singletons(monkeypatch) + _reset_process_singletons(monkeypatch) + + from deerflow.config import app_config as app_config_module + + cfg = app_config_module.get_app_config() + cfg.database.sqlite_dir = str(isolated_deer_flow_home / "db") + + from app.gateway.app import create_app + + return create_app() + + +def _register_user(client, *, email: str = "runtime-e2e@example.com") -> str: + response = client.post( + "/api/v1/auth/register", + json={"email": email, "password": "very-strong-password-123"}, + ) + assert response.status_code == 201, response.text + csrf_token = client.cookies.get("csrf_token") + assert csrf_token + return csrf_token + + +def _create_thread(client, csrf_token: str) -> str: + thread_id = str(uuid.uuid4()) + response = client.post( + "/api/threads", + json={"thread_id": thread_id, "metadata": {"purpose": "runtime-lifecycle-e2e"}}, + headers={"X-CSRF-Token": csrf_token}, + ) + assert response.status_code == 200, response.text + return thread_id + + +def _run_body(**overrides) -> dict[str, Any]: + body: dict[str, Any] = { + "assistant_id": "lead_agent", + "input": {"messages": [{"role": "user", "content": "Run lifecycle E2E prompt"}]}, + "config": {"recursion_limit": 50}, + "stream_mode": ["values"], + } + body.update(overrides) + return body + + +def _drain_stream(response, *, timeout: float = 10.0, max_bytes: int = 1024 * 1024) -> str: + chunks: queue.Queue[bytes | BaseException | object] = queue.Queue() + sentinel = object() + + def read_stream() -> None: + try: + for chunk in response.iter_bytes(): + chunks.put(chunk) + if b"event: end" in chunk: + break + except BaseException as exc: # pragma: no cover - reported in the main test thread + chunks.put(exc) + finally: + chunks.put(sentinel) + + reader = threading.Thread(target=read_stream, daemon=True) + reader.start() + + deadline = time.monotonic() + timeout + body = b"" + while True: + remaining = deadline - time.monotonic() + if remaining <= 0: + raise AssertionError(f"SSE stream did not finish within {timeout}s; transcript tail={body[-4000:].decode('utf-8', errors='replace')}") + try: + chunk = chunks.get(timeout=remaining) + except queue.Empty as exc: + raise AssertionError(f"SSE stream did not produce data within {timeout}s; transcript tail={body[-4000:].decode('utf-8', errors='replace')}") from exc + if chunk is sentinel: + break + if isinstance(chunk, BaseException): + raise AssertionError("SSE reader failed") from chunk + body += chunk + if b"event: end" in body: + break + if len(body) >= max_bytes: + raise AssertionError(f"SSE stream exceeded {max_bytes} bytes without event: end") + if b"event: end" not in body: + raise AssertionError(f"SSE stream closed before event: end; transcript tail={body[-4000:].decode('utf-8', errors='replace')}") + return body.decode("utf-8", errors="replace") + + +def _parse_sse(transcript: str) -> list[dict[str, Any]]: + events: list[dict[str, Any]] = [] + for raw_frame in transcript.split("\n\n"): + frame = raw_frame.strip() + if not frame or frame.startswith(":"): + continue + parsed: dict[str, Any] = {} + for line in frame.splitlines(): + if line.startswith("event: "): + parsed["event"] = line.removeprefix("event: ") + elif line.startswith("data: "): + payload = line.removeprefix("data: ") + parsed["data"] = json.loads(payload) + elif line.startswith("id: "): + parsed["id"] = line.removeprefix("id: ") + if parsed: + events.append(parsed) + return events + + +def _run_id_from_response(response) -> str: + location = response.headers.get("content-location", "") + assert location, "run stream response must include Content-Location" + return location.rstrip("/").split("/")[-1] + + +def _wait_for_status(client, thread_id: str, run_id: str, status: str, *, timeout: float = 5.0) -> dict: + deadline = time.monotonic() + timeout + last: dict | None = None + while time.monotonic() < deadline: + response = client.get(f"/api/threads/{thread_id}/runs/{run_id}") + assert response.status_code == 200, response.text + last = response.json() + if last["status"] == status: + return last + time.sleep(0.05) + raise AssertionError(f"Run {run_id} did not reach {status!r}; last={last!r}") + + +def _thread_id_from_config(config: dict | None) -> str: + config = config or {} + context = config.get("context") if isinstance(config.get("context"), dict) else {} + configurable = config.get("configurable") if isinstance(config.get("configurable"), dict) else {} + thread_id = context.get("thread_id") or configurable.get("thread_id") + assert thread_id, f"runtime config did not contain thread_id: {config!r}" + return str(thread_id) + + +def _last_human_text(graph_input: dict) -> str: + messages = graph_input.get("messages") or [] + if not messages: + return "" + last = messages[-1] + content = getattr(last, "content", last) + if isinstance(content, str): + return content + return str(content) + + +async def _write_checkpoint(checkpointer: Any, *, thread_id: str, state: dict[str, Any]) -> None: + from langgraph.checkpoint.base import empty_checkpoint + + checkpoint = empty_checkpoint() + checkpoint["channel_values"] = dict(state) + checkpoint["channel_versions"] = {key: 1 for key in state} + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + metadata = { + "source": "loop", + "step": 1, + "writes": {"scripted_agent": {"title": state.get("title"), "message_count": len(state.get("messages", []))}}, + "parents": {}, + } + + result = checkpointer.aput(config, checkpoint, metadata, {}) + if inspect.isawaitable(result): + await result + + +def _stream_item_for_mode(stream_mode: Any, state: dict[str, Any]) -> Any: + if isinstance(stream_mode, list): + # ``run_agent`` passes a list when multiple modes/subgraphs are active. + return stream_mode[0], state + return state + + +def test_stream_run_completes_and_persists_runtime_state(isolated_app): + """A streaming run should traverse the real runtime and leave state behind.""" + from starlette.testclient import TestClient + + controller = _RunController() + factory = _make_agent_factory( + controller, + title="Lifecycle E2E", + answer="Lifecycle complete.", + ) + + with ( + patch("app.gateway.services.resolve_agent_factory", return_value=factory), + TestClient(isolated_app) as client, + ): + csrf_token = _register_user(client) + thread_id = _create_thread(client, csrf_token) + + with client.stream( + "POST", + f"/api/threads/{thread_id}/runs/stream", + json=_run_body(), + headers={"X-CSRF-Token": csrf_token}, + ) as response: + assert response.status_code == 200, response.read().decode() + run_id = _run_id_from_response(response) + transcript = _drain_stream(response) + + events = _parse_sse(transcript) + assert [event["event"] for event in events] == ["metadata", "values", "end"] + assert events[0]["data"] == {"run_id": run_id, "thread_id": thread_id} + assert events[1]["data"]["title"] == "Lifecycle E2E" + assert events[1]["data"]["messages"][-1]["content"] == "Lifecycle complete." + + run = client.get(f"/api/threads/{thread_id}/runs/{run_id}") + assert run.status_code == 200, run.text + assert run.json()["status"] == "success" + + thread = client.get(f"/api/threads/{thread_id}") + assert thread.status_code == 200, thread.text + assert thread.json()["status"] == "idle" + assert thread.json()["values"]["title"] == "Lifecycle E2E" + + messages = client.get(f"/api/threads/{thread_id}/runs/{run_id}/messages") + assert messages.status_code == 200, messages.text + message_events = messages.json()["data"] + event_types = [row["event_type"] for row in message_events] + assert "llm.human.input" in event_types + assert "llm.ai.response" in event_types + assert any(row["content"]["content"] == "Run lifecycle E2E prompt" for row in message_events if row["event_type"] == "llm.human.input") + assert any(row["content"]["content"] == "Lifecycle complete." for row in message_events if row["event_type"] == "llm.ai.response") + + +def test_stream_run_executes_real_lead_agent_setup_agent_business_path(isolated_app, isolated_deer_flow_home: Path): + """A runtime stream should execute real lead-agent business code and tools.""" + from starlette.testclient import TestClient + + agent_name = "runtime-business-agent" + + with ( + patch( + "deerflow.agents.lead_agent.agent.create_chat_model", + new=_build_fake_setup_agent_model(agent_name), + ), + TestClient(isolated_app) as client, + ): + csrf_token = _register_user(client, email="business-e2e@example.com") + auth_user_id = client.get("/api/v1/auth/me").json()["id"] + thread_id = _create_thread(client, csrf_token) + + body = _run_body( + input={ + "messages": [ + { + "role": "user", + "content": f"Create a custom agent named {agent_name}.", + } + ] + }, + context={ + "agent_name": agent_name, + "is_bootstrap": True, + "thinking_enabled": False, + "is_plan_mode": False, + "subagent_enabled": False, + }, + ) + + with client.stream( + "POST", + f"/api/threads/{thread_id}/runs/stream", + json=body, + headers={"X-CSRF-Token": csrf_token}, + ) as response: + assert response.status_code == 200, response.read().decode() + run_id = _run_id_from_response(response) + transcript = _drain_stream(response, timeout=20.0) + + events = _parse_sse(transcript) + event_names = [event["event"] for event in events] + assert "metadata" in event_names + assert "error" not in event_names, transcript + assert event_names[-1] == "end" + + run = _wait_for_status(client, thread_id, run_id, "success", timeout=10.0) + assert run["assistant_id"] == "lead_agent" + + expected_soul = isolated_deer_flow_home / "users" / auth_user_id / "agents" / agent_name / "SOUL.md" + assert expected_soul.exists(), f"setup_agent did not write SOUL.md. tmp tree: {sorted(str(p.relative_to(isolated_deer_flow_home)) for p in isolated_deer_flow_home.rglob('SOUL.md'))}" + assert f"Agent name: {agent_name}" in expected_soul.read_text(encoding="utf-8") + assert not (isolated_deer_flow_home / "users" / "default" / "agents" / agent_name).exists() + + +def test_cancel_interrupt_stops_running_background_run(isolated_app): + """HTTP cancel?action=interrupt should stop the worker and persist interruption.""" + from starlette.testclient import TestClient + + controller = _RunController() + factory = _make_agent_factory( + controller, + title="Interrupt candidate", + answer="This run should be interrupted.", + block_after_first_chunk=True, + ) + + with ( + patch("app.gateway.services.resolve_agent_factory", return_value=factory), + TestClient(isolated_app) as client, + ): + csrf_token = _register_user(client, email="interrupt-e2e@example.com") + thread_id = _create_thread(client, csrf_token) + + created = client.post( + f"/api/threads/{thread_id}/runs", + json=_run_body(), + headers={"X-CSRF-Token": csrf_token}, + ) + assert created.status_code == 200, created.text + run_id = created.json()["run_id"] + assert controller.started.wait(5), "fake agent never started" + + cancelled = client.post( + f"/api/threads/{thread_id}/runs/{run_id}/cancel?wait=true&action=interrupt", + headers={"X-CSRF-Token": csrf_token}, + ) + assert cancelled.status_code == 204, cancelled.text + assert controller.cancelled.wait(5), "fake agent task was not cancelled" + + run = _wait_for_status(client, thread_id, run_id, "interrupted") + assert run["status"] == "interrupted" + + thread = client.get(f"/api/threads/{thread_id}") + assert thread.status_code == 200, thread.text + assert thread.json()["status"] == "idle" + + +@pytest.mark.anyio +async def test_sse_consumer_disconnect_cancels_inflight_run(): + """A disconnected SSE request should cancel an in-flight run when configured.""" + from app.gateway.services import sse_consumer + from deerflow.runtime import DisconnectMode, MemoryStreamBridge, RunManager, RunStatus + + bridge = MemoryStreamBridge() + run_manager = RunManager() + record = await run_manager.create("thread-disconnect", on_disconnect=DisconnectMode.cancel) + await run_manager.set_status(record.run_id, RunStatus.running) + await bridge.publish(record.run_id, "metadata", {"run_id": record.run_id, "thread_id": record.thread_id}) + worker_started = asyncio.Event() + worker_cancelled = asyncio.Event() + + async def _pending_worker() -> None: + try: + worker_started.set() + await asyncio.Event().wait() + except asyncio.CancelledError: + worker_cancelled.set() + raise + + record.task = asyncio.create_task(_pending_worker()) + await asyncio.wait_for(worker_started.wait(), timeout=1.0) + + class _DisconnectedRequest: + headers: dict[str, str] = {} + + async def is_disconnected(self) -> bool: + return True + + try: + frames = [] + async for frame in sse_consumer(bridge, record, _DisconnectedRequest(), run_manager): + frames.append(frame) + + assert frames == [] + assert record.abort_event.is_set() + assert record.status == RunStatus.interrupted + await asyncio.wait_for(worker_cancelled.wait(), timeout=1.0) + assert record.task.cancelled() + finally: + if record.task is not None and not record.task.done(): + record.task.cancel() + with suppress(asyncio.CancelledError): + await record.task + + +def test_cancel_rollback_restores_pre_run_checkpoint(isolated_app): + """HTTP cancel?action=rollback should restore the checkpoint captured before run start.""" + from starlette.testclient import TestClient + + controller = _RunController() + factory = _make_agent_factory( + controller, + title="During rollback run", + answer="This answer should be rolled back.", + block_after_first_chunk=True, + ) + + with ( + patch("app.gateway.services.resolve_agent_factory", return_value=factory), + TestClient(isolated_app) as client, + ): + csrf_token = _register_user(client, email="rollback-e2e@example.com") + thread_id = _create_thread(client, csrf_token) + + before = client.post( + f"/api/threads/{thread_id}/state", + json={ + "values": { + "title": "Before rollback", + "messages": [{"type": "human", "content": "before"}], + }, + "as_node": "test_seed", + }, + headers={"X-CSRF-Token": csrf_token}, + ) + assert before.status_code == 200, before.text + assert before.json()["values"]["title"] == "Before rollback" + + created = client.post( + f"/api/threads/{thread_id}/runs", + json=_run_body(), + headers={"X-CSRF-Token": csrf_token}, + ) + assert created.status_code == 200, created.text + run_id = created.json()["run_id"] + assert controller.checkpoint_written.wait(5), "fake agent did not write in-run checkpoint" + + during = client.get(f"/api/threads/{thread_id}/state") + assert during.status_code == 200, during.text + assert during.json()["values"]["title"] == "During rollback run" + + rolled_back = client.post( + f"/api/threads/{thread_id}/runs/{run_id}/cancel?wait=true&action=rollback", + headers={"X-CSRF-Token": csrf_token}, + ) + assert rolled_back.status_code == 204, rolled_back.text + assert controller.cancelled.wait(5), "rollback did not cancel the worker task" + + run = _wait_for_status(client, thread_id, run_id, "error") + assert run["status"] == "error" + + after = client.get(f"/api/threads/{thread_id}/state") + assert after.status_code == 200, after.text + assert after.json()["values"]["title"] == "Before rollback" + assert after.json()["values"]["messages"] == [{"type": "human", "content": "before"}] From 9b19cca91c7d33dee2d39607edf19be3ef2e9558 Mon Sep 17 00:00:00 2001 From: john lee <64lamei@gmail.com> Date: Wed, 20 May 2026 16:37:36 +0800 Subject: [PATCH 49/86] fix(runtime): make RunManager.cancel() idempotent for already-interrupted runs (#3055) (#3058) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A second cancel() call on an interrupted run returned False, causing the cancel and stream_existing_run router endpoints to raise 409 on double-stop. Fix: return True inside the lock when record.status == RunStatus.interrupted. This covers both the POST /cancel and POST /join endpoints without any re-fetch or extra get() call — the idempotency lives at the source. Also fixes stream_existing_run (the LangGraph SDK stop-button path), which had the identical cancel() → 409 pattern and was not covered by the original PR. Both endpoints share the fix automatically. Co-authored-by: Claude Sonnet 4.6 --- .../harness/deerflow/runtime/runs/manager.py | 7 +- backend/tests/test_cancel_run_idempotent.py | 142 ++++++++++++++++++ 2 files changed, 148 insertions(+), 1 deletion(-) create mode 100644 backend/tests/test_cancel_run_idempotent.py diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 06731eb91..ea78f89c9 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -258,12 +258,17 @@ class RunManager: action: "interrupt" keeps checkpoint, "rollback" reverts to pre-run state. Sets the abort event with the action reason and cancels the asyncio task. - Returns ``True`` if the run was in-flight and cancellation was initiated. + Returns ``True`` if cancellation was initiated **or** the run was already + interrupted (idempotent — a second cancel is a no-op success). + Returns ``False`` only when the run is unknown to this worker or has + reached a terminal state other than interrupted (completed, failed, etc.). """ async with self._lock: record = self._runs.get(run_id) if record is None: return False + if record.status == RunStatus.interrupted: + return True # idempotent — already cancelled on this worker if record.status not in (RunStatus.pending, RunStatus.running): return False record.abort_action = action diff --git a/backend/tests/test_cancel_run_idempotent.py b/backend/tests/test_cancel_run_idempotent.py new file mode 100644 index 000000000..0bf2548d1 --- /dev/null +++ b/backend/tests/test_cancel_run_idempotent.py @@ -0,0 +1,142 @@ +"""Tests for idempotent run cancellation (issue #3055). + +RunManager.cancel() returns True when a run is already interrupted so that +a second cancel request from the same worker is treated as a no-op success +(202) rather than a conflict (409). Both the POST cancel endpoint and the +POST stream endpoint share this behaviour through the same cancel() call. +""" + +from __future__ import annotations + +import asyncio + +from _router_auth_helpers import make_authed_test_app +from fastapi.testclient import TestClient + +from app.gateway.routers import thread_runs +from deerflow.runtime import RunManager, RunStatus + +THREAD_ID = "thread-cancel-test" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(mgr: RunManager) -> TestClient: + app = make_authed_test_app() + app.include_router(thread_runs.router) + app.state.run_manager = mgr + return TestClient(app, raise_server_exceptions=False) + + +def _create_interrupted_run(mgr: RunManager) -> str: + """Create a run and cancel it, returning its run_id.""" + + async def _setup(): + record = await mgr.create(THREAD_ID) + await mgr.set_status(record.run_id, RunStatus.running) + await mgr.cancel(record.run_id) + return record.run_id + + return asyncio.run(_setup()) + + +# --------------------------------------------------------------------------- +# RunManager.cancel() unit tests +# --------------------------------------------------------------------------- + + +class TestRunManagerCancelIdempotency: + def test_cancel_returns_true_for_already_interrupted_run(self): + """cancel() must return True when the run is already interrupted.""" + + async def run(): + mgr = RunManager() + record = await mgr.create(THREAD_ID) + await mgr.set_status(record.run_id, RunStatus.running) + first = await mgr.cancel(record.run_id) + assert first is True + second = await mgr.cancel(record.run_id) + assert second is True # idempotent + + asyncio.run(run()) + + def test_cancel_returns_false_for_successful_run(self): + """cancel() must still return False for runs that completed successfully.""" + + async def run(): + mgr = RunManager() + record = await mgr.create(THREAD_ID) + await mgr.set_status(record.run_id, RunStatus.running) + await mgr.set_status(record.run_id, RunStatus.success) + result = await mgr.cancel(record.run_id) + assert result is False + + asyncio.run(run()) + + def test_cancel_returns_false_for_unknown_run(self): + async def run(): + mgr = RunManager() + result = await mgr.cancel("nonexistent-run-id") + assert result is False + + asyncio.run(run()) + + +# --------------------------------------------------------------------------- +# POST /cancel endpoint — idempotent 202 +# --------------------------------------------------------------------------- + + +class TestCancelRunEndpointIdempotency: + def test_double_cancel_returns_202_not_409(self): + """Second cancel on an already-interrupted run must return 202, not 409.""" + mgr = RunManager() + run_id = _create_interrupted_run(mgr) + client = _make_app(mgr) + + resp = client.post(f"/api/threads/{THREAD_ID}/runs/{run_id}/cancel") + assert resp.status_code == 202, f"Expected 202, got {resp.status_code}: {resp.text}" + + def test_cancel_unknown_run_returns_404(self): + mgr = RunManager() + client = _make_app(mgr) + resp = client.post(f"/api/threads/{THREAD_ID}/runs/no-such-run/cancel") + assert resp.status_code == 404 + + def test_cancel_successful_run_returns_409(self): + """Successfully-completed runs cannot be cancelled — must return 409.""" + + async def _setup(): + mgr = RunManager() + record = await mgr.create(THREAD_ID) + await mgr.set_status(record.run_id, RunStatus.running) + await mgr.set_status(record.run_id, RunStatus.success) + return mgr, record.run_id + + mgr, run_id = asyncio.run(_setup()) + client = _make_app(mgr) + resp = client.post(f"/api/threads/{THREAD_ID}/runs/{run_id}/cancel") + assert resp.status_code == 409 + + +# --------------------------------------------------------------------------- +# POST /{thread_id}/runs/{run_id}/join (stream_existing_run) — idempotent cancel +# --------------------------------------------------------------------------- + + +class TestStreamExistingRunIdempotentCancel: + def test_stream_cancel_already_interrupted_returns_not_409(self): + """stream_existing_run with action=interrupt on an already-interrupted run + must not raise 409 — the idempotent cancel path returns 202/SSE.""" + mgr = RunManager() + run_id = _create_interrupted_run(mgr) + client = _make_app(mgr) + + resp = client.post( + f"/api/threads/{THREAD_ID}/runs/{run_id}/join", + params={"action": "interrupt"}, + ) + assert resp.status_code != 409, f"Should not 409 on idempotent cancel, got {resp.status_code}" From 9abe5a18e6e9c5fc4f73f852a9d23cbac6149438 Mon Sep 17 00:00:00 2001 From: DanielWalnut <45447813+hetaoBackend@users.noreply.github.com> Date: Wed, 20 May 2026 22:26:02 +0800 Subject: [PATCH 50/86] fix: clean up local nginx on stop (#3005) * fix: clean up local nginx on stop * fix: scope local service cleanup to repo * fix: address serve port review comments --- scripts/serve.sh | 138 ++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 124 insertions(+), 14 deletions(-) diff --git a/scripts/serve.sh b/scripts/serve.sh index a45ff1af1..485c9b5fe 100755 --- a/scripts/serve.sh +++ b/scripts/serve.sh @@ -62,27 +62,129 @@ done # ── Stop helper ────────────────────────────────────────────────────────────── -_kill_port() { +_is_repo_pid() { + local pid=$1 + lsof -p "$pid" 2>/dev/null | grep -F "$REPO_ROOT" >/dev/null +} + +_kill_repo_processes() { + local pattern=$1 + local pid + local pids="" + + while IFS= read -r pid; do + if [ -n "$pid" ] && _is_repo_pid "$pid"; then + case " $pids " in + *" $pid "*) ;; + *) pids="$pids $pid" ;; + esac + fi + done < <(pgrep -f "$pattern" 2>/dev/null || true) + + if [ -n "$pids" ]; then + kill $pids 2>/dev/null || true + fi +} + +_kill_repo_port() { local port=$1 local pid - pid=$(lsof -ti :"$port" 2>/dev/null) || true - if [ -n "$pid" ]; then - kill -9 $pid 2>/dev/null || true + local pids="" + + while IFS= read -r pid; do + if [ -n "$pid" ] && _is_repo_pid "$pid"; then + case " $pids " in + *" $pid "*) ;; + *) pids="$pids $pid" ;; + esac + fi + done < <(lsof -nP -iTCP:"$port" -sTCP:LISTEN -t 2>/dev/null || true) + + if [ -n "$pids" ]; then + kill -9 $pids 2>/dev/null || true + fi +} + +_is_port_listening() { + local port=$1 + + if command -v lsof >/dev/null 2>&1; then + if lsof -nP -iTCP:"$port" -sTCP:LISTEN -t >/dev/null 2>&1; then + return 0 + fi + fi + + if command -v ss >/dev/null 2>&1; then + if ss -ltn "( sport = :$port )" 2>/dev/null | tail -n +2 | grep -q .; then + return 0 + fi + fi + + if command -v netstat >/dev/null 2>&1; then + if netstat -ltn 2>/dev/null | awk '{print $4}' | grep -Eq "(^|[.:])${port}$"; then + return 0 + fi + fi + + return 1 +} + +_is_repo_nginx_pid() { + local pid=$1 + local command + local args + + command=$(ps -p "$pid" -o comm= 2>/dev/null) || return 1 + case "$command" in + nginx|*/nginx) ;; + *) return 1 ;; + esac + + args=$(ps -p "$pid" -o args= 2>/dev/null) || return 1 + case "$args" in + *"$REPO_ROOT/docker/nginx/nginx.local.conf"*|*"$REPO_ROOT"*) return 0 ;; + esac + + _is_repo_pid "$pid" +} + +_kill_repo_nginx() { + local pid + local pids="" + + if [ -f "$REPO_ROOT/logs/nginx.pid" ]; then + read -r pid < "$REPO_ROOT/logs/nginx.pid" || true + if [ -n "$pid" ] && _is_repo_nginx_pid "$pid"; then + pids="$pids $pid" + fi + fi + + while IFS= read -r pid; do + if [ -n "$pid" ] && _is_repo_nginx_pid "$pid"; then + case " $pids " in + *" $pid "*) ;; + *) pids="$pids $pid" ;; + esac + fi + done < <(pgrep -f nginx 2>/dev/null || true) + + if [ -n "$pids" ]; then + kill -9 $pids 2>/dev/null || true fi } stop_all() { echo "Stopping all services..." - pkill -f "uvicorn app.gateway.app:app" 2>/dev/null || true - pkill -f "next dev" 2>/dev/null || true - pkill -f "next start" 2>/dev/null || true - pkill -f "next-server" 2>/dev/null || true + _kill_repo_processes "uvicorn app.gateway.app:app" + _kill_repo_processes "next dev" + _kill_repo_processes "next start" + _kill_repo_processes "next-server" nginx -c "$REPO_ROOT/docker/nginx/nginx.local.conf" -p "$REPO_ROOT" -s quit 2>/dev/null || true sleep 1 - pkill -9 nginx 2>/dev/null || true + _kill_repo_nginx # Force-kill any survivors still holding the service ports - _kill_port 8001 - _kill_port 3000 + _kill_repo_port 8001 + _kill_repo_port 3000 ./scripts/cleanup-containers.sh deer-flow-sandbox 2>/dev/null || true echo "✓ All services stopped" } @@ -216,13 +318,15 @@ echo "" # ── Cleanup handler ────────────────────────────────────────────────────────── cleanup() { + local status="${1:-0}" trap - INT TERM echo "" stop_all - exit 0 + exit "$status" } -trap cleanup INT TERM +trap 'cleanup 130' INT +trap 'cleanup 143' TERM # ── Helper: start a service ────────────────────────────────────────────────── @@ -231,6 +335,12 @@ trap cleanup INT TERM run_service() { local name="$1" cmd="$2" port="$3" timeout="$4" + if _is_port_listening "$port"; then + echo "✗ $name cannot start because port $port is already in use." + echo " If it belongs to this worktree, run 'make stop'; otherwise free the port manually." + cleanup 1 + fi + echo "Starting $name..." if $DAEMON_MODE; then nohup sh -c "$cmd" > /dev/null 2>&1 & @@ -242,7 +352,7 @@ run_service() { local logfile="logs/$(echo "$name" | tr '[:upper:]' '[:lower:]' | tr ' ' '-').log" echo "✗ $name failed to start." [ -f "$logfile" ] && tail -20 "$logfile" - cleanup + cleanup 1 } echo "✓ $name started on localhost:$port" } From b6b3650e50feff7ffae4fec91786ecb6be760a3d Mon Sep 17 00:00:00 2001 From: Airene Fang Date: Wed, 20 May 2026 22:34:10 +0800 Subject: [PATCH 51/86] =?UTF-8?q?fix(trace):memory=20=E4=B8=AD=E6=96=87=20?= =?UTF-8?q?in=20trace=20info=20is=20unicode=20escape=20sequence.=20(#3104)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(trace):memory 中文 in trace is unicode * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- .../harness/deerflow/agents/memory/updater.py | 2 +- backend/tests/test_memory_updater.py | 35 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/backend/packages/harness/deerflow/agents/memory/updater.py b/backend/packages/harness/deerflow/agents/memory/updater.py index 6e55330a1..2007a97e2 100644 --- a/backend/packages/harness/deerflow/agents/memory/updater.py +++ b/backend/packages/harness/deerflow/agents/memory/updater.py @@ -338,7 +338,7 @@ class MemoryUpdater: reinforcement_detected=reinforcement_detected, ) prompt = MEMORY_UPDATE_PROMPT.format( - current_memory=json.dumps(current_memory, indent=2), + current_memory=json.dumps(current_memory, indent=2, ensure_ascii=False), conversation=conversation_text, correction_hint=correction_hint, ) diff --git a/backend/tests/test_memory_updater.py b/backend/tests/test_memory_updater.py index 03d135564..038cec627 100644 --- a/backend/tests/test_memory_updater.py +++ b/backend/tests/test_memory_updater.py @@ -78,6 +78,41 @@ def test_apply_updates_skips_existing_duplicate_and_preserves_removals() -> None assert all(fact["id"] != "fact_remove" for fact in result["facts"]) +def test_prepare_update_prompt_preserves_non_ascii_memory_text() -> None: + updater = MemoryUpdater() + current_memory = _make_memory( + facts=[ + { + "id": "fact_cn", + "content": "Deer-flow是一个非常好的框架。", + "category": "context", + "confidence": 0.9, + "createdAt": "2026-05-20T00:00:00Z", + "source": "thread-cn", + }, + ] + ) + + with ( + patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)), + patch("deerflow.agents.memory.updater.get_memory_data", return_value=current_memory), + ): + msg = MagicMock() + msg.type = "human" + msg.content = "你好" + prepared = updater._prepare_update_prompt( + [msg], + agent_name=None, + correction_detected=False, + reinforcement_detected=False, + ) + + assert prepared is not None + _, prompt = prepared + assert "Deer-flow是一个非常好的框架。" in prompt + assert "\\u" not in prompt + + def test_apply_updates_skips_same_batch_duplicates_and_keeps_source_metadata() -> None: updater = MemoryUpdater() current_memory = _make_memory() From 9afeaf66bc1b8fcd190ebf233525510248888d4a Mon Sep 17 00:00:00 2001 From: Yuyi Ao Date: Wed, 20 May 2026 16:27:00 -0700 Subject: [PATCH 52/86] Fix env resolution in MCP config lists (#2556) * Fix env resolution in MCP config lists * fix:unset env variable and consistent function --------- Co-authored-by: Willem Jiang --- .../deerflow/config/extensions_config.py | 41 ++++++++++--------- backend/tests/test_mcp_client_config.py | 20 +++++++++ 2 files changed, 42 insertions(+), 19 deletions(-) diff --git a/backend/packages/harness/deerflow/config/extensions_config.py b/backend/packages/harness/deerflow/config/extensions_config.py index a2daa71f4..425da12b8 100644 --- a/backend/packages/harness/deerflow/config/extensions_config.py +++ b/backend/packages/harness/deerflow/config/extensions_config.py @@ -141,7 +141,7 @@ class ExtensionsConfig(BaseModel): try: with open(resolved_path, encoding="utf-8") as f: config_data = json.load(f) - cls.resolve_env_variables(config_data) + config_data = cls.resolve_env_variables(config_data) return cls.model_validate(config_data) except json.JSONDecodeError as e: raise ValueError(f"Extensions config file at {resolved_path} is not valid JSON: {e}") from e @@ -149,7 +149,7 @@ class ExtensionsConfig(BaseModel): raise RuntimeError(f"Failed to load extensions config from {resolved_path}: {e}") from e @classmethod - def resolve_env_variables(cls, config: dict[str, Any]) -> dict[str, Any]: + def resolve_env_variables(cls, config: Any) -> Any: """Recursively resolve environment variables in the config. Environment variables are resolved using the `os.getenv` function. Example: $OPENAI_API_KEY @@ -160,23 +160,26 @@ class ExtensionsConfig(BaseModel): Returns: The config with environment variables resolved. """ - for key, value in config.items(): - if isinstance(value, str): - if value.startswith("$"): - env_value = os.getenv(value[1:]) - if env_value is None: - # Unresolved placeholder — store empty string so downstream - # consumers (e.g. MCP servers) don't receive the literal "$VAR" - # token as an actual environment value. - config[key] = "" - else: - config[key] = env_value - else: - config[key] = value - elif isinstance(value, dict): - config[key] = cls.resolve_env_variables(value) - elif isinstance(value, list): - config[key] = [cls.resolve_env_variables(item) if isinstance(item, dict) else item for item in value] + if isinstance(config, str): + if not config.startswith("$"): + return config + env_value = os.getenv(config[1:]) + if env_value is None: + # Unresolved placeholder — store empty string so downstream + # consumers (e.g. MCP servers) don't receive the literal "$VAR" + # token as an actual environment value. + return "" + return env_value + + if isinstance(config, dict): + return {key: cls.resolve_env_variables(value) for key, value in config.items()} + + if isinstance(config, list): + return [cls.resolve_env_variables(item) for item in config] + + if isinstance(config, tuple): + return tuple(cls.resolve_env_variables(item) for item in config) + return config def get_enabled_mcp_servers(self) -> dict[str, McpServerConfig]: diff --git a/backend/tests/test_mcp_client_config.py b/backend/tests/test_mcp_client_config.py index 6d0083c0c..ca4d0de59 100644 --- a/backend/tests/test_mcp_client_config.py +++ b/backend/tests/test_mcp_client_config.py @@ -24,6 +24,26 @@ def test_build_server_params_stdio_success(): } +def test_extensions_config_resolves_env_variables_inside_nested_collections(monkeypatch): + monkeypatch.setenv("MCP_TOKEN", "secret") + monkeypatch.delenv("MISSING_TOKEN", raising=False) + raw_config = { + "args": ["--token", "$MCP_TOKEN", {"nested": ["$MCP_TOKEN", "$MISSING_TOKEN"]}], + "tuple_args": ("$MCP_TOKEN", "$MISSING_TOKEN"), + "env": {"API_KEY": "$MCP_TOKEN"}, + "enabled": True, + "timeout": 30, + } + + resolved = ExtensionsConfig.resolve_env_variables(raw_config) + + assert resolved["args"] == ["--token", "secret", {"nested": ["secret", ""]}] + assert resolved["tuple_args"] == ("secret", "") + assert resolved["env"] == {"API_KEY": "secret"} + assert resolved["enabled"] is True + assert resolved["timeout"] == 30 + + def test_build_server_params_stdio_requires_command(): config = McpServerConfig(type="stdio", command=None) From e19bec1422e51750e299457ce44a50d84aa152fe Mon Sep 17 00:00:00 2001 From: InitBoy <804255496@qq.com> Date: Thu, 21 May 2026 07:47:19 +0800 Subject: [PATCH 53/86] fix(task-tool): cancel and schedule deferred cleanup on polling safety timeout (#3097) When the poll loop's safety-net timeout fires (poll_count > max_poll_count), the background subagent task was abandoned without cancellation or cleanup, leaving a stale entry in _background_tasks indefinitely. The original code had a comment promising "the cleanup will happen when the executor completes", but run_task() in executor.py never calls cleanup_background_task after reaching a terminal state -- the promise was never implemented. This change mirrors the asyncio.CancelledError path: signal cooperative cancellation via request_cancel_background_task and schedule _deferred_cleanup_subagent_task to remove the entry once the background thread reaches a terminal state. Direct cleanup at poll-timeout time would introduce a race: run_task() could remove the entry while the poll loop is still mid-iteration, causing a spurious "Task disappeared" error. The deferred approach avoids this by waiting for terminal state before removal. Co-authored-by: Claude Sonnet 4.6 --- .../deerflow/tools/builtins/task_tool.py | 8 +++-- backend/tests/test_task_tool_core_logic.py | 30 +++++++++++++++---- 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index cf9281ff4..a45bff787 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -383,9 +383,6 @@ async def task_tool( # Polling timeout as a safety net (in case thread pool timeout doesn't work) # Set to execution timeout + 60s buffer, in 5s poll intervals # This catches edge cases where the background task gets stuck - # Note: We don't call cleanup_background_task here because the task may - # still be running in the background. The cleanup will happen when the - # executor completes and sets a terminal status. if poll_count > max_poll_count: timeout_minutes = config.timeout_seconds // 60 logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)") @@ -393,6 +390,11 @@ async def task_tool( usage = _summarize_usage(getattr(result, "token_usage_records", None)) _cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage) writer({"type": "task_timed_out", "task_id": task_id, "usage": usage}) + # The task may still be running in the background. Signal cooperative + # cancellation and schedule deferred cleanup to remove the entry from + # _background_tasks once the background thread reaches a terminal state. + request_cancel_background_task(task_id) + _schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count) return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}" except asyncio.CancelledError: # Signal the background subagent thread to stop cooperatively. diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 658968d65..dc0f844d3 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -732,17 +732,27 @@ def test_cleanup_called_on_timed_out(monkeypatch): def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): - """Verify cleanup_background_task is NOT called on polling safety timeout. + """Verify cleanup_background_task is NOT called directly on polling safety timeout. - This prevents race conditions where the background task is still running - but the polling loop gives up. The cleanup should happen later when the - executor completes and sets a terminal status. + The task is still RUNNING so it cannot be safely removed yet. Instead, + cooperative cancellation is requested and a deferred cleanup is scheduled. """ config = _make_subagent_config() # Keep max_poll_count small for test speed: (1 + 60) // 5 = 12 config.timeout_seconds = 1 events = [] cleanup_calls = [] + cancel_requests = [] + scheduled_cleanups = [] + + class DummyCleanupTask: + def add_done_callback(self, _callback): + return None + + def fake_create_task(coro): + scheduled_cleanups.append(coro) + coro.close() + return DummyCleanupTask() monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) monkeypatch.setattr( @@ -759,12 +769,18 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): ) monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr(task_tool_module.asyncio, "create_task", fake_create_task) monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: []) monkeypatch.setattr( task_tool_module, "cleanup_background_task", lambda task_id: cleanup_calls.append(task_id), ) + monkeypatch.setattr( + task_tool_module, + "request_cancel_background_task", + lambda task_id: cancel_requests.append(task_id), + ) output = _run_task_tool( runtime=_make_runtime(), @@ -775,8 +791,12 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch): ) assert output.startswith("Task polling timed out after 0 minutes") - # cleanup should NOT be called because the task is still RUNNING + # cleanup_background_task must NOT be called directly (task is still RUNNING) assert cleanup_calls == [] + # cooperative cancellation must be requested + assert cancel_requests == ["tc-no-cleanup-safety-timeout"] + # a deferred cleanup coroutine must be scheduled + assert len(scheduled_cleanups) == 1 def test_cleanup_scheduled_on_cancellation(monkeypatch): From 7ec8d3a6e7cb1f549118e451f4cf4a09365c8449 Mon Sep 17 00:00:00 2001 From: sunsine <135408348+sunshine-lang@users.noreply.github.com> Date: Thu, 21 May 2026 10:28:57 +0800 Subject: [PATCH 54/86] fix(security): mask sensitive values in MCP config API responses (#2667) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(security): mask sensitive values in MCP config API responses GET /api/mcp/config previously returned plaintext secrets including env dict values (API keys), headers (auth tokens), and OAuth client_secret/refresh_token. Any authenticated user could read all MCP service credentials. This commit masks sensitive fields in GET/PUT responses while preserving the key structure so the frontend round-trip (GET masked → toggle enabled → PUT) correctly preserves existing secrets. Co-Authored-By: Claude Opus 4.6 * fix(security): address Copilot review on MCP config masking - Load raw JSON (un-resolved $VAR placeholders) as merge source instead of resolved config, preventing plaintext secrets from replacing $VAR placeholders on disk (Comment 2) - Preserve all top-level keys (e.g. mcpInterceptors) in PUT, not just mcpServers/skills (Comment 1) - Reject masked value '***' for new keys that don't exist in existing config, returning 400 with actionable error (Comment 3) - Allow empty string '' to explicitly clear OAuth secrets, while None means 'preserve existing' for safe round-trip (Comment 4) - Add 3 new tests for rejection, clearing, and edge cases (18 total) Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 --- backend/app/gateway/routers/mcp.py | 138 +++++++++- backend/tests/test_mcp_config_secrets.py | 305 +++++++++++++++++++++++ 2 files changed, 434 insertions(+), 9 deletions(-) create mode 100644 backend/tests/test_mcp_config_secrets.py diff --git a/backend/app/gateway/routers/mcp.py b/backend/app/gateway/routers/mcp.py index 386fc13c6..d38406266 100644 --- a/backend/app/gateway/routers/mcp.py +++ b/backend/app/gateway/routers/mcp.py @@ -63,6 +63,99 @@ class McpConfigUpdateRequest(BaseModel): ) +_MASKED_VALUE = "***" + + +def _mask_server_config(server: McpServerConfigResponse) -> McpServerConfigResponse: + """Return a copy of server config with sensitive fields masked. + + Masks env values, header values, and removes OAuth secrets so they + are not exposed through the GET API endpoint. + """ + masked_env = {k: _MASKED_VALUE for k in server.env} + masked_headers = {k: _MASKED_VALUE for k in server.headers} + masked_oauth = None + if server.oauth is not None: + masked_oauth = server.oauth.model_copy( + update={ + "client_secret": None, + "refresh_token": None, + } + ) + return server.model_copy( + update={ + "env": masked_env, + "headers": masked_headers, + "oauth": masked_oauth, + } + ) + + +def _merge_preserving_secrets( + incoming: McpServerConfigResponse, + existing: McpServerConfigResponse, +) -> McpServerConfigResponse: + """Merge incoming config with existing, preserving secrets masked by GET. + + When the frontend toggles ``enabled`` it round-trips the full config: + GET (masked) → modify enabled → PUT (masked values sent back). + This function ensures masked values (``***``) are replaced with the + real secrets from the current on-disk config. + + ``***`` is only accepted for keys that already exist in *existing*. + New keys must provide a real value. + + For OAuth secrets, ``None`` means "preserve the existing stored value" + so masked GET responses can be safely round-tripped. To explicitly clear + a stored secret, clients may send an empty string, which is converted + to ``None`` before persisting. + """ + merged_env = {} + for k, v in incoming.env.items(): + if v == _MASKED_VALUE: + if k in existing.env: + merged_env[k] = existing.env[k] + else: + raise HTTPException( + status_code=400, + detail=f"Cannot set env key '{k}' to masked value '***'; provide a real value.", + ) + else: + merged_env[k] = v + + merged_headers = {} + for k, v in incoming.headers.items(): + if v == _MASKED_VALUE: + if k in existing.headers: + merged_headers[k] = existing.headers[k] + else: + raise HTTPException( + status_code=400, + detail=f"Cannot set header '{k}' to masked value '***'; provide a real value.", + ) + else: + merged_headers[k] = v + + merged_oauth = incoming.oauth + if incoming.oauth is not None and existing.oauth is not None: + # None = preserve (masked round-trip), "" = explicitly clear, else = new value + merged_client_secret = existing.oauth.client_secret if incoming.oauth.client_secret is None else (None if incoming.oauth.client_secret == "" else incoming.oauth.client_secret) + merged_refresh_token = existing.oauth.refresh_token if incoming.oauth.refresh_token is None else (None if incoming.oauth.refresh_token == "" else incoming.oauth.refresh_token) + merged_oauth = incoming.oauth.model_copy( + update={ + "client_secret": merged_client_secret, + "refresh_token": merged_refresh_token, + } + ) + return incoming.model_copy( + update={ + "env": merged_env, + "headers": merged_headers, + "oauth": merged_oauth, + } + ) + + @router.get( "/mcp/config", response_model=McpConfigResponse, @@ -83,7 +176,7 @@ async def get_mcp_configuration() -> McpConfigResponse: "enabled": true, "command": "npx", "args": ["-y", "@modelcontextprotocol/server-github"], - "env": {"GITHUB_TOKEN": "ghp_xxx"}, + "env": {"GITHUB_TOKEN": "***"}, "description": "GitHub MCP server for repository operations" } } @@ -92,7 +185,8 @@ async def get_mcp_configuration() -> McpConfigResponse: """ config = get_extensions_config() - return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in config.mcp_servers.items()}) + servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in config.mcp_servers.items()} + return McpConfigResponse(mcp_servers=servers) @router.put( @@ -142,14 +236,39 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig config_path = Path.cwd().parent / "extensions_config.json" logger.info(f"No existing extensions config found. Creating new config at: {config_path}") - # Load current config to preserve skills configuration + # Load current config to preserve skills current_config = get_extensions_config() - # Convert request to dict format for JSON serialization - config_data = { - "mcpServers": {name: server.model_dump() for name, server in request.mcp_servers.items()}, - "skills": {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()}, - } + # Load raw (un-resolved) JSON from disk to use as the merge source. + # This preserves $VAR placeholders in env values and top-level keys + # like mcpInterceptors that would otherwise be lost. + raw_servers: dict[str, dict] = {} + raw_other_keys: dict = {} + if config_path is not None and config_path.exists(): + with open(config_path, encoding="utf-8") as f: + raw_data = json.load(f) + raw_servers = raw_data.get("mcpServers", {}) + # Preserve any top-level keys beyond mcpServers/skills + for key, value in raw_data.items(): + if key not in ("mcpServers", "skills"): + raw_other_keys[key] = value + + # Merge incoming server configs with raw on-disk secrets + merged_servers: dict[str, McpServerConfigResponse] = {} + for name, incoming in request.mcp_servers.items(): + raw_server = raw_servers.get(name) + if raw_server is not None: + merged_servers[name] = _merge_preserving_secrets( + incoming, + McpServerConfigResponse(**raw_server), + ) + else: + merged_servers[name] = incoming + + # Build config data preserving all top-level keys from the original file + config_data = dict(raw_other_keys) + config_data["mcpServers"] = {name: server.model_dump() for name, server in merged_servers.items()} + config_data["skills"] = {name: {"enabled": skill.enabled} for name, skill in current_config.skills.items()} # Write the configuration to file with open(config_path, "w", encoding="utf-8") as f: @@ -162,7 +281,8 @@ async def update_mcp_configuration(request: McpConfigUpdateRequest) -> McpConfig # Reload the configuration and update the global cache reloaded_config = reload_extensions_config() - return McpConfigResponse(mcp_servers={name: McpServerConfigResponse(**server.model_dump()) for name, server in reloaded_config.mcp_servers.items()}) + servers = {name: _mask_server_config(McpServerConfigResponse(**server.model_dump())) for name, server in reloaded_config.mcp_servers.items()} + return McpConfigResponse(mcp_servers=servers) except Exception as e: logger.error(f"Failed to update MCP configuration: {e}", exc_info=True) diff --git a/backend/tests/test_mcp_config_secrets.py b/backend/tests/test_mcp_config_secrets.py new file mode 100644 index 000000000..831b8611b --- /dev/null +++ b/backend/tests/test_mcp_config_secrets.py @@ -0,0 +1,305 @@ +"""Tests for MCP config secret masking and preservation. + +Verifies that GET /api/mcp/config masks sensitive fields (env values, +header values, OAuth secrets) and that PUT /api/mcp/config correctly +preserves existing secrets when the frontend round-trips masked values. +""" + +from __future__ import annotations + +import pytest + +from app.gateway.routers.mcp import ( + McpOAuthConfigResponse, + McpServerConfigResponse, + _mask_server_config, + _merge_preserving_secrets, +) + +# --------------------------------------------------------------------------- +# _mask_server_config +# --------------------------------------------------------------------------- + + +def test_mask_replaces_env_values_with_asterisks(): + """Env dict values should be replaced with '***'.""" + server = McpServerConfigResponse( + env={"GITHUB_TOKEN": "ghp_real_secret_123", "API_KEY": "sk-abc"}, + ) + masked = _mask_server_config(server) + assert masked.env == {"GITHUB_TOKEN": "***", "API_KEY": "***"} + + +def test_mask_replaces_header_values_with_asterisks(): + """Header dict values should be replaced with '***'.""" + server = McpServerConfigResponse( + headers={"Authorization": "Bearer tok_123", "X-API-Key": "key_456"}, + ) + masked = _mask_server_config(server) + assert masked.headers == {"Authorization": "***", "X-API-Key": "***"} + + +def test_mask_removes_oauth_secrets(): + """OAuth client_secret and refresh_token should be set to None.""" + server = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_id="my-client", + client_secret="super-secret", + refresh_token="refresh-token-abc", + token_url="https://auth.example.com/token", + ), + ) + masked = _mask_server_config(server) + assert masked.oauth is not None + assert masked.oauth.client_secret is None + assert masked.oauth.refresh_token is None + # Non-secret fields preserved + assert masked.oauth.client_id == "my-client" + assert masked.oauth.token_url == "https://auth.example.com/token" + + +def test_mask_preserves_non_secret_fields(): + """Non-sensitive fields should pass through unchanged.""" + server = McpServerConfigResponse( + enabled=True, + type="stdio", + command="npx", + args=["-y", "@modelcontextprotocol/server-github"], + env={"KEY": "val"}, + description="GitHub MCP server", + ) + masked = _mask_server_config(server) + assert masked.enabled is True + assert masked.type == "stdio" + assert masked.command == "npx" + assert masked.args == ["-y", "@modelcontextprotocol/server-github"] + assert masked.description == "GitHub MCP server" + + +def test_mask_handles_empty_env_and_headers(): + """Empty env/headers dicts should remain empty.""" + server = McpServerConfigResponse() + masked = _mask_server_config(server) + assert masked.env == {} + assert masked.headers == {} + + +def test_mask_handles_no_oauth(): + """Server without OAuth should remain None.""" + server = McpServerConfigResponse(oauth=None) + masked = _mask_server_config(server) + assert masked.oauth is None + + +def test_mask_does_not_mutate_original(): + """Masking should return a new object, not modify the original.""" + server = McpServerConfigResponse(env={"KEY": "secret"}) + masked = _mask_server_config(server) + assert server.env["KEY"] == "secret" + assert masked.env["KEY"] == "***" + + +# --------------------------------------------------------------------------- +# _merge_preserving_secrets +# --------------------------------------------------------------------------- + + +def test_merge_preserves_masked_env_values(): + """Incoming '***' env values should be replaced with existing secrets.""" + incoming = McpServerConfigResponse(env={"KEY": "***"}) + existing = McpServerConfigResponse(env={"KEY": "real_secret"}) + merged = _merge_preserving_secrets(incoming, existing) + assert merged.env["KEY"] == "real_secret" + + +def test_merge_preserves_masked_header_values(): + """Incoming '***' header values should be replaced with existing secrets.""" + incoming = McpServerConfigResponse(headers={"Authorization": "***"}) + existing = McpServerConfigResponse(headers={"Authorization": "Bearer real"}) + merged = _merge_preserving_secrets(incoming, existing) + assert merged.headers["Authorization"] == "Bearer real" + + +def test_merge_preserves_oauth_secrets_when_none(): + """Incoming None oauth secrets should preserve existing values.""" + incoming = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_secret=None, + refresh_token=None, + token_url="https://auth.example.com/token", + ), + ) + existing = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_secret="existing-secret", + refresh_token="existing-refresh", + token_url="https://auth.example.com/token", + ), + ) + merged = _merge_preserving_secrets(incoming, existing) + assert merged.oauth is not None + assert merged.oauth.client_secret == "existing-secret" + assert merged.oauth.refresh_token == "existing-refresh" + + +def test_merge_accepts_new_secret_values(): + """Incoming real secret values should replace existing ones.""" + incoming = McpServerConfigResponse( + env={"KEY": "new_secret"}, + oauth=McpOAuthConfigResponse( + client_secret="new-client-secret", + refresh_token="new-refresh-token", + token_url="https://auth.example.com/token", + ), + ) + existing = McpServerConfigResponse( + env={"KEY": "old_secret"}, + oauth=McpOAuthConfigResponse( + client_secret="old-secret", + refresh_token="old-refresh", + token_url="https://auth.example.com/token", + ), + ) + merged = _merge_preserving_secrets(incoming, existing) + assert merged.env["KEY"] == "new_secret" + assert merged.oauth.client_secret == "new-client-secret" + assert merged.oauth.refresh_token == "new-refresh-token" + + +def test_merge_handles_no_existing_oauth(): + """When existing has no oauth but incoming does, keep incoming.""" + incoming = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_secret="new-secret", + token_url="https://auth.example.com/token", + ), + ) + existing = McpServerConfigResponse(oauth=None) + merged = _merge_preserving_secrets(incoming, existing) + assert merged.oauth is not None + assert merged.oauth.client_secret == "new-secret" + + +def test_merge_does_not_mutate_original(): + """Merge should return a new object, not modify the original.""" + incoming = McpServerConfigResponse(env={"KEY": "***"}) + existing = McpServerConfigResponse(env={"KEY": "secret"}) + merged = _merge_preserving_secrets(incoming, existing) + assert incoming.env["KEY"] == "***" + assert existing.env["KEY"] == "secret" + assert merged.env["KEY"] == "secret" + + +# --------------------------------------------------------------------------- +# Comment 2 fix: masked value for new key is rejected +# --------------------------------------------------------------------------- + + +def test_merge_rejects_masked_value_for_new_env_key(): + """Sending '***' for a key that doesn't exist in existing should raise 400.""" + from fastapi import HTTPException + + incoming = McpServerConfigResponse(env={"NEW_KEY": "***"}) + existing = McpServerConfigResponse(env={}) + with pytest.raises(HTTPException) as exc_info: + _merge_preserving_secrets(incoming, existing) + assert exc_info.value.status_code == 400 + assert "NEW_KEY" in exc_info.value.detail + + +def test_merge_rejects_masked_value_for_new_header_key(): + """Sending '***' for a header key that doesn't exist should raise 400.""" + from fastapi import HTTPException + + incoming = McpServerConfigResponse(headers={"X-New-Auth": "***"}) + existing = McpServerConfigResponse(headers={}) + with pytest.raises(HTTPException) as exc_info: + _merge_preserving_secrets(incoming, existing) + assert exc_info.value.status_code == 400 + assert "X-New-Auth" in exc_info.value.detail + + +# --------------------------------------------------------------------------- +# Comment 4 fix: empty string clears OAuth secrets +# --------------------------------------------------------------------------- + + +def test_merge_empty_string_clears_oauth_client_secret(): + """Sending '' for client_secret should clear the stored value.""" + incoming = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_secret="", + refresh_token=None, + token_url="https://auth.example.com/token", + ), + ) + existing = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_secret="existing-secret", + refresh_token="existing-refresh", + token_url="https://auth.example.com/token", + ), + ) + merged = _merge_preserving_secrets(incoming, existing) + assert merged.oauth.client_secret is None + assert merged.oauth.refresh_token == "existing-refresh" + + +def test_merge_empty_string_clears_oauth_refresh_token(): + """Sending '' for refresh_token should clear the stored value.""" + incoming = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_secret=None, + refresh_token="", + token_url="https://auth.example.com/token", + ), + ) + existing = McpServerConfigResponse( + oauth=McpOAuthConfigResponse( + client_secret="existing-secret", + refresh_token="existing-refresh", + token_url="https://auth.example.com/token", + ), + ) + merged = _merge_preserving_secrets(incoming, existing) + assert merged.oauth.client_secret == "existing-secret" + assert merged.oauth.refresh_token is None + + +# --------------------------------------------------------------------------- +# Round-trip integration: mask → merge should preserve original secrets +# --------------------------------------------------------------------------- + + +def test_roundtrip_mask_then_merge_preserves_original_secrets(): + """Simulates the full frontend round-trip: GET (masked) → toggle → PUT.""" + original = McpServerConfigResponse( + enabled=True, + env={"GITHUB_TOKEN": "ghp_real_secret"}, + headers={"Authorization": "Bearer real_token"}, + oauth=McpOAuthConfigResponse( + client_id="client-123", + client_secret="oauth-secret", + refresh_token="refresh-abc", + token_url="https://auth.example.com/token", + ), + description="GitHub MCP server", + ) + + # Step 1: Server returns masked config (simulates GET response) + masked = _mask_server_config(original) + assert masked.env["GITHUB_TOKEN"] == "***" + assert masked.oauth.client_secret is None + + # Step 2: Frontend toggles enabled and sends back (simulates PUT request) + from_frontend = masked.model_copy(update={"enabled": False}) + + # Step 3: Server merges with existing secrets (simulates PUT handler) + restored = _merge_preserving_secrets(from_frontend, original) + assert restored.enabled is False + assert restored.env["GITHUB_TOKEN"] == "ghp_real_secret" + assert restored.headers["Authorization"] == "Bearer real_token" + assert restored.oauth.client_secret == "oauth-secret" + assert restored.oauth.refresh_token == "refresh-abc" + # Non-secret fields from the update are preserved + assert restored.description == "GitHub MCP server" From dcc6f1e6789512c05c9674cb86aa58792c1d9066 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Thu, 21 May 2026 08:36:07 +0200 Subject: [PATCH 55/86] feat(loop-detection): defer warning injection (#2752) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(loop-detection): defer warn injection to wrap_model_call The warn branch in LoopDetectionMiddleware injected a HumanMessage into state from after_model. The tools node had not yet produced ToolMessage responses to the previous AIMessage(tool_calls=...), so the new HumanMessage landed *between* the assistant's tool_calls and their responses. OpenAI/Moonshot reject the next request with "tool_call_ids did not have response messages" because their validators require tool_calls to be followed immediately by tool messages. Detection now runs in after_model as before, but only enqueues the warning into a per-thread list. Injection happens in wrap_model_call, where every prior ToolMessage is already present in request.messages. The warning is appended at the end as HumanMessage(name="loop_warning") — pairing intact, AIMessage semantics untouched, no SystemMessage issues for Anthropic. Closes #2029, addresses #2255 #2293 #2304 #2511. Co-Authored-By: Claude Opus 4.7 * fix(channels): remove loop warning display filter * feat(loop-detection): scope pending warnings by run * docs(loop-detection): update docs * test(loop-detection): assert deferred warnings are queued * fix(loop-detection): cap transient warning state * docs: update docs * add async awrap_model_call test coverage * docs(loop-detection): document transient warnings --------- Co-authored-by: Claude Opus 4.7 --- backend/app/channels/manager.py | 15 - backend/docs/middleware-execution-flow.md | 144 ++--- .../middlewares/loop_detection_middleware.py | 229 +++++++- backend/tests/test_channels.py | 31 -- .../tests/test_loop_detection_middleware.py | 494 +++++++++++++++--- .../src/content/en/harness/middlewares.mdx | 2 + .../src/content/zh/harness/middlewares.mdx | 2 + 7 files changed, 696 insertions(+), 221 deletions(-) diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index aa52fa298..015f91e58 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -146,13 +146,6 @@ def _normalize_custom_agent_name(raw_value: str) -> str: return normalized -def _strip_loop_warning_text(text: str) -> str: - """Remove middleware-authored loop warning lines from display text.""" - if "[LOOP DETECTED]" not in text: - return text - return "\n".join(line for line in text.splitlines() if "[LOOP DETECTED]" not in line).strip() - - def _extract_response_text(result: dict | list) -> str: """Extract the last AI message text from a LangGraph runs.wait result. @@ -162,7 +155,6 @@ def _extract_response_text(result: dict | list) -> str: Handles special cases: - Regular AI text responses - Clarification interrupts (``ask_clarification`` tool messages) - - Strips loop-detection warnings attached to tool-call AI messages """ if isinstance(result, list): messages = result @@ -192,12 +184,7 @@ def _extract_response_text(result: dict | list) -> str: # Regular AI message with text content if msg_type == "ai": content = msg.get("content", "") - has_tool_calls = bool(msg.get("tool_calls")) if isinstance(content, str) and content: - if has_tool_calls: - content = _strip_loop_warning_text(content) - if not content: - continue return content # content can be a list of content blocks if isinstance(content, list): @@ -208,8 +195,6 @@ def _extract_response_text(result: dict | list) -> str: elif isinstance(block, str): parts.append(block) text = "".join(parts) - if has_tool_calls: - text = _strip_loop_warning_text(text) if text: return text return "" diff --git a/backend/docs/middleware-execution-flow.md b/backend/docs/middleware-execution-flow.md index 922cc9640..99d638938 100644 --- a/backend/docs/middleware-execution-flow.md +++ b/backend/docs/middleware-execution-flow.md @@ -4,22 +4,22 @@ `create_deerflow_agent` 通过 `RuntimeFeatures` 组装的完整 middleware 链(默认全开时): -| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_tool_call` | 主 Agent | Subagent | 来源 | -|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------| -| 0 | ThreadDataMiddleware | ✓ | | | | | ✓ | ✓ | `sandbox` | -| 1 | UploadsMiddleware | ✓ | | | | | ✓ | ✗ | `sandbox` | -| 2 | SandboxMiddleware | ✓ | | | ✓ | | ✓ | ✓ | `sandbox` | -| 3 | DanglingToolCallMiddleware | | | ✓ | | | ✓ | ✗ | 始终开启 | -| 4 | GuardrailMiddleware | | | | | ✓ | ✓ | ✓ | *Phase 2 纳入* | -| 5 | ToolErrorHandlingMiddleware | | | | | ✓ | ✓ | ✓ | 始终开启 | -| 6 | SummarizationMiddleware | | | ✓ | | | ✓ | ✗ | `summarization` | -| 7 | TodoMiddleware | | | ✓ | | | ✓ | ✗ | `plan_mode` 参数 | -| 8 | TitleMiddleware | | | ✓ | | | ✓ | ✗ | `auto_title` | -| 9 | MemoryMiddleware | | | | ✓ | | ✓ | ✗ | `memory` | -| 10 | ViewImageMiddleware | | ✓ | | | | ✓ | ✗ | `vision` | -| 11 | SubagentLimitMiddleware | | | ✓ | | | ✓ | ✗ | `subagent` | -| 12 | LoopDetectionMiddleware | | | ✓ | | | ✓ | ✗ | 始终开启 | -| 13 | ClarificationMiddleware | | | ✓ | | | ✓ | ✗ | 始终最后 | +| # | Middleware | `before_agent` | `before_model` | `after_model` | `after_agent` | `wrap_model_call` | `wrap_tool_call` | 主 Agent | Subagent | 来源 | +|---|-----------|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|------| +| 0 | ThreadDataMiddleware | ✓ | | | | | | ✓ | ✓ | `sandbox` | +| 1 | UploadsMiddleware | ✓ | | | | | | ✓ | ✗ | `sandbox` | +| 2 | SandboxMiddleware | ✓ | | | ✓ | | | ✓ | ✓ | `sandbox` | +| 3 | DanglingToolCallMiddleware | | | | | ✓ | | ✓ | ✗ | 始终开启 | +| 4 | GuardrailMiddleware | | | | | | ✓ | ✓ | ✓ | *Phase 2 纳入* | +| 5 | ToolErrorHandlingMiddleware | | | | | | ✓ | ✓ | ✓ | 始终开启 | +| 6 | SummarizationMiddleware | | ✓ | | | | | ✓ | ✗ | `summarization` | +| 7 | TodoMiddleware | | ✓ | ✓ | | ✓ | | ✓ | ✗ | `plan_mode` 参数 | +| 8 | TitleMiddleware | | | ✓ | | | | ✓ | ✗ | `auto_title` | +| 9 | MemoryMiddleware | | | | ✓ | | | ✓ | ✗ | `memory` | +| 10 | ViewImageMiddleware | | ✓ | | | | | ✓ | ✗ | `vision` | +| 11 | SubagentLimitMiddleware | | | ✓ | | | | ✓ | ✗ | `subagent` | +| 12 | LoopDetectionMiddleware | ✓ | | ✓ | ✓ | ✓ | | ✓ | ✗ | 始终开启 | +| 13 | ClarificationMiddleware | | | | | | ✓ | ✓ | ✗ | 始终最后 | 主 agent **14 个** middleware(`make_lead_agent`),subagent **4 个**(ThreadData、Sandbox、Guardrail、ToolErrorHandling)。`create_deerflow_agent` Phase 1 实现 **13 个**(Guardrail 仅支持自定义实例,无内置默认)。 @@ -35,7 +35,7 @@ graph TB subgraph BA ["before_agent 正序 0→N"] direction TB - TD["[0] ThreadData
创建线程目录"] --> UL["[1] Uploads
扫描上传文件"] --> SB["[2] Sandbox
获取沙箱"] + TD["[0] ThreadData
创建线程目录"] --> UL["[1] Uploads
扫描上传文件"] --> SB["[2] Sandbox
获取沙箱"] --> LD_BA["[12] LoopDetection
清理 stale warning"] end subgraph BM ["before_model 正序 0→N"] @@ -43,34 +43,42 @@ graph TB VI["[10] ViewImage
注入图片 base64"] end - SB --> VI - VI --> M["MODEL"] + subgraph WM ["wrap_model_call"] + direction TB + DTC_WM["[3] DanglingToolCall
补悬空 ToolMessage"] --> LD_WM["[12] LoopDetection
注入当前 run warning"] + end + + LD_BA --> VI + VI --> DTC_WM + LD_WM --> M["MODEL"] subgraph AM ["after_model 反序 N→0"] direction TB - CL["[13] Clarification
拦截 ask_clarification"] --> LD["[12] LoopDetection
检测循环"] --> SL["[11] SubagentLimit
截断多余 task"] --> TI["[8] Title
生成标题"] --> SM["[6] Summarization
上下文压缩"] --> DTC["[3] DanglingToolCall
补缺失 ToolMessage"] + LD["[12] LoopDetection
检测循环/排队 warning"] --> SL["[11] SubagentLimit
截断多余 task"] --> TI["[8] Title
生成标题"] end - M --> CL + M --> LD subgraph AA ["after_agent 反序 N→0"] direction TB - SBR["[2] Sandbox
释放沙箱"] --> MEM["[9] Memory
入队记忆"] + LD_CLEAN["[12] LoopDetection
清理 pending warning"] --> MEM["[9] Memory
入队记忆"] --> SBR["[2] Sandbox
释放沙箱"] end - DTC --> SBR - MEM --> END(["response"]) + TI --> LD_CLEAN + SBR --> END(["response"]) classDef beforeNode fill:#a0a8b5,stroke:#636b7a,color:#2d3239 classDef modelNode fill:#b5a8a0,stroke:#7a6b63,color:#2d3239 + classDef wrapModelNode fill:#a8a0b5,stroke:#6b637a,color:#2d3239 classDef afterModelNode fill:#b5a0a8,stroke:#7a636b,color:#2d3239 classDef afterAgentNode fill:#a0b5a8,stroke:#637a6b,color:#2d3239 classDef terminalNode fill:#a8b5a0,stroke:#6b7a63,color:#2d3239 - class TD,UL,SB,VI beforeNode + class TD,UL,SB,LD_BA,VI beforeNode + class DTC_WM,LD_WM wrapModelNode class M modelNode - class CL,LD,SL,TI,SM,DTC afterModelNode - class SBR,MEM afterAgentNode + class LD,SL,TI afterModelNode + class LD_CLEAN,SBR,MEM afterAgentNode class START,END terminalNode ``` @@ -82,13 +90,12 @@ sequenceDiagram participant TD as ThreadDataMiddleware participant UL as UploadsMiddleware participant SB as SandboxMiddleware + participant LD as LoopDetectionMiddleware participant VI as ViewImageMiddleware + participant DTC as DanglingToolCallMiddleware participant M as MODEL - participant CL as ClarificationMiddleware participant SL as SubagentLimitMiddleware participant TI as TitleMiddleware - participant SM as SummarizationMiddleware - participant DTC as DanglingToolCallMiddleware participant MEM as MemoryMiddleware U ->> TD: invoke @@ -103,19 +110,26 @@ sequenceDiagram activate SB Note right of SB: before_agent 获取沙箱 - SB ->> VI: before_model + SB ->> LD: before_agent + activate LD + Note right of LD: before_agent 清理同 thread 旧 run 的 pending warning + LD ->> VI: before_model activate VI Note right of VI: before_model 注入图片 base64 - VI ->> M: messages + tools + VI ->> DTC: wrap_model_call + activate DTC + Note right of DTC: wrap_model_call 补悬空 ToolMessage + DTC ->> LD: wrap_model_call + Note right of LD: wrap_model_call drain 当前 run warning 并追加到末尾 + LD ->> M: messages + tools activate M - M -->> CL: AI response + M -->> LD: AI response deactivate M - activate CL - Note right of CL: after_model 拦截 ask_clarification - CL -->> SL: after_model - deactivate CL + Note right of LD: after_model 检测循环;warning 入队,hard-stop 清 tool_calls + LD -->> SL: after_model + deactivate LD activate SL Note right of SL: after_model 截断多余 task @@ -124,22 +138,18 @@ sequenceDiagram activate TI Note right of TI: after_model 生成标题 - TI -->> SM: after_model + TI -->> DTC: done deactivate TI - activate SM - Note right of SM: after_model 上下文压缩 - SM -->> DTC: after_model - deactivate SM - - activate DTC - Note right of DTC: after_model 补缺失 ToolMessage - DTC -->> VI: done deactivate DTC VI -->> SB: done deactivate VI + Note right of LD: after_agent 清理当前 run 未消费 warning + + Note right of MEM: after_agent 入队记忆 + Note right of SB: after_agent 释放沙箱 SB -->> UL: done deactivate SB @@ -147,8 +157,6 @@ sequenceDiagram UL -->> TD: done deactivate UL - Note right of MEM: after_agent 入队记忆 - TD -->> U: response deactivate TD ``` @@ -224,12 +232,12 @@ sequenceDiagram participant TD as ThreadData participant UL as Uploads participant SB as Sandbox + participant LD as LoopDetection participant VI as ViewImage + participant DTC as DanglingToolCall participant M as MODEL - participant CL as Clarification participant SL as SubagentLimit participant TI as Title - participant SM as Summarization participant MEM as Memory U ->> TD: invoke @@ -238,34 +246,40 @@ sequenceDiagram Note right of UL: before_agent 扫描文件 UL ->> SB: . Note right of SB: before_agent 获取沙箱 + SB ->> LD: . + Note right of LD: before_agent 清理 stale pending warning loop 每轮对话(tool call 循环) SB ->> VI: . Note right of VI: before_model 注入图片 - VI ->> M: messages + tools - M -->> CL: AI response - Note right of CL: after_model 拦截 ask_clarification - CL -->> SL: . + VI ->> DTC: . + Note right of DTC: wrap_model_call 补悬空工具结果 + DTC ->> LD: . + Note right of LD: wrap_model_call 注入当前 run warning + LD ->> M: messages + tools + M -->> LD: AI response + Note right of LD: after_model 检测循环/排队 warning + LD -->> SL: . Note right of SL: after_model 截断多余 task SL -->> TI: . Note right of TI: after_model 生成标题 - TI -->> SM: . - Note right of SM: after_model 上下文压缩 end - Note right of SB: after_agent 释放沙箱 - SB -->> MEM: . + Note right of LD: after_agent 清理当前 run pending warning + LD -->> MEM: . Note right of MEM: after_agent 入队记忆 - MEM -->> U: response + MEM -->> SB: . + Note right of SB: after_agent 释放沙箱 + SB -->> U: response ``` > [!warning] 不是洋葱 -> 14 个 middleware 中只有 SandboxMiddleware 有 before/after 对称(获取/释放)。其余都是单向的:要么只在 `before_*` 做事,要么只在 `after_*` 做事。`before_agent` / `after_agent` 只跑一次,`before_model` / `after_model` 每轮循环都跑。 +> 大部分 middleware 只用一个阶段。SandboxMiddleware 使用 `before_agent`/`after_agent` 做资源获取/释放;LoopDetectionMiddleware 也使用这两个钩子,但用途是清理 run-scoped pending warnings,不是资源生命周期对称。`before_agent` / `after_agent` 只跑一次,`before_model` / `after_model` / `wrap_model_call` 每轮循环都跑。 硬依赖只有 2 处: 1. **ThreadData 在 Sandbox 之前** — sandbox 需要线程目录 -2. **Clarification 在列表最后** — `after_model` 反序时最先执行,第一个拦截 `ask_clarification` +2. **Clarification 在列表最后** — `wrap_tool_call` 处理 `ask_clarification` 时优先拦截,并通过 `Command(goto=END)` 中断执行 ### 结论 @@ -273,19 +287,19 @@ sequenceDiagram |---|---|---| | 每个 middleware | before + after 对称 | 大多只用一个钩子 | | 激活条 | 嵌套(外长内短) | 不嵌套(串行) | -| 反序的意义 | 清理与初始化配对 | 仅影响 after_model 的执行优先级 | +| 反序的意义 | 清理与初始化配对 | 影响 `after_model` / `after_agent` 的执行优先级 | | 典型例子 | Auth: 校验 token / 清理上下文 | ThreadData: 只创建目录,没有清理 | ## 关键设计点 ### ClarificationMiddleware 为什么在列表最后? -位置最后 = `after_model` 最先执行。它需要**第一个**看到 model 输出,检查是否有 `ask_clarification` tool call。如果有,立即中断(`Command(goto=END)`),后续 middleware 的 `after_model` 不再执行。 +位置最后使它在工具调用包装链中优先拦截 `ask_clarification`。如果命中,它返回 `Command(goto=END)`,把格式化后的澄清问题写成 `ToolMessage` 并中断执行。 ### SandboxMiddleware 的对称性 `before_agent`(正序第 3 个)获取沙箱,`after_agent`(反序第 1 个)释放沙箱。外层进入 → 外层退出,天然的洋葱对称。 -### 大部分 middleware 只用一个钩子 +### LoopDetectionMiddleware 为什么同时用多个钩子? -14 个 middleware 中,只有 SandboxMiddleware 同时用了 `before_agent` + `after_agent`(获取/释放)。其余都只在一个阶段执行。洋葱模型的反序特性主要影响 `after_model` 阶段的执行顺序。 +`after_model` 只做检测:重复工具调用达到 warning 阈值时,把 warning 放入 `(thread_id, run_id)` 作用域的 pending 队列。真正注入发生在下一次 `wrap_model_call`:此时上一轮 `AIMessage(tool_calls)` 对应的 `ToolMessage` 已经在请求里,warning 追加在末尾,不会破坏 OpenAI/Moonshot 的 tool-call pairing。`before_agent` 清理同一 thread 下旧 run 的残留 warning,`after_agent` 清理当前 run 没被消费的 warning。 diff --git a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py index db83051e9..396377952 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/loop_detection_middleware.py @@ -6,10 +6,36 @@ arguments indefinitely until the recursion limit kills the run. Detection strategy: 1. After each model response, hash the tool calls (name + args). 2. Track recent hashes in a sliding window. - 3. If the same hash appears >= warn_threshold times, inject a - "you are repeating yourself — wrap up" system message (once per hash). + 3. If the same hash appears >= warn_threshold times, queue a + "you are repeating yourself — wrap up" warning for the current + thread/run. The warning is **injected at the next model call** (in + ``wrap_model_call``) as a ``HumanMessage`` appended to the message + list, *after* all ToolMessage responses to the previous + AIMessage(tool_calls). 4. If it appears >= hard_limit times, strip all tool_calls from the response so the agent is forced to produce a final text answer. + +Why the warning is injected at ``wrap_model_call`` instead of +``after_model``: + + ``after_model`` fires immediately after the model emits an + ``AIMessage`` that may carry ``tool_calls``. The tools node has not + run yet, so no matching ``ToolMessage`` exists in the history. Any + message we add here lands *between* the assistant's tool_calls and + their responses. OpenAI/Moonshot reject the next request with + ``"tool_call_ids did not have response messages"`` because their + validators require the assistant's tool_calls to be followed + immediately by tool messages. Anthropic also disallows mid-stream + ``SystemMessage``. By deferring the warning to ``wrap_model_call``, + every prior ToolMessage is already present in the request's message + list and the warning is appended at the end — pairing intact, no + ``AIMessage`` semantics are mutated. + +Queued warnings are intentionally transient. If a run ends before the +next model request drains a queued warning, ``after_agent`` drops it +instead of carrying it into a later invocation for the same thread. The +hard-stop path still forces termination when the configured safety limit +is reached. """ from __future__ import annotations @@ -19,11 +45,14 @@ import json import logging import threading from collections import OrderedDict, defaultdict +from collections.abc import Awaitable, Callable from copy import deepcopy from typing import TYPE_CHECKING, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse +from langchain_core.messages import HumanMessage from langgraph.runtime import Runtime if TYPE_CHECKING: @@ -38,6 +67,7 @@ _DEFAULT_WINDOW_SIZE = 20 # track last N tool calls _DEFAULT_MAX_TRACKED_THREADS = 100 # LRU eviction limit _DEFAULT_TOOL_FREQ_WARN = 30 # warn after 30 calls to the same tool type _DEFAULT_TOOL_FREQ_HARD_LIMIT = 50 # force-stop after 50 calls to the same tool type +_MAX_PENDING_WARNINGS_PER_RUN = 4 def _normalize_tool_call_args(raw_args: object) -> tuple[dict, str | None]: @@ -195,6 +225,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): self._warned: dict[str, set[str]] = defaultdict(set) self._tool_freq: dict[str, dict[str, int]] = defaultdict(lambda: defaultdict(int)) self._tool_freq_warned: dict[str, set[str]] = defaultdict(set) + # Per-thread/run queue of warnings to inject at the next model call. + # Populated by ``after_model`` (detection) and drained by + # ``wrap_model_call`` (injection); see module docstring. + self._pending_warnings: dict[tuple[str, str], list[str]] = defaultdict(list) + self._pending_warning_touch_order: OrderedDict[tuple[str, str], None] = OrderedDict() + self._max_pending_warning_keys = max(1, self.max_tracked_threads * 2) @classmethod def from_config(cls, config: LoopDetectionConfig) -> LoopDetectionMiddleware: @@ -213,9 +249,20 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): """Extract thread_id from runtime context for per-thread tracking.""" thread_id = runtime.context.get("thread_id") if runtime.context else None if thread_id: - return thread_id + return str(thread_id) return "default" + def _get_run_id(self, runtime: Runtime) -> str: + """Extract run_id from runtime context for per-run warning scoping.""" + run_id = runtime.context.get("run_id") if runtime.context else None + if run_id: + return str(run_id) + return "default" + + def _pending_key(self, runtime: Runtime) -> tuple[str, str]: + """Return the pending-warning key for the current thread/run.""" + return self._get_thread_id(runtime), self._get_run_id(runtime) + def _evict_if_needed(self) -> None: """Evict least recently used threads if over the limit. @@ -226,8 +273,52 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): self._warned.pop(evicted_id, None) self._tool_freq.pop(evicted_id, None) self._tool_freq_warned.pop(evicted_id, None) + for key in list(self._pending_warnings): + if key[0] == evicted_id: + self._drop_pending_warning_key_locked(key) logger.debug("Evicted loop tracking for thread %s (LRU)", evicted_id) + def _drop_pending_warning_key_locked(self, key: tuple[str, str]) -> None: + """Drop all pending-warning bookkeeping for one thread/run key. + + Must be called while holding self._lock. + """ + self._pending_warnings.pop(key, None) + self._pending_warning_touch_order.pop(key, None) + + def _touch_pending_warning_key_locked(self, key: tuple[str, str]) -> None: + """Mark a pending-warning key as recently used. + + Must be called while holding self._lock. + """ + self._pending_warning_touch_order[key] = None + self._pending_warning_touch_order.move_to_end(key) + + def _prune_pending_warning_state_locked(self, protected_key: tuple[str, str]) -> None: + """Cap pending-warning state across abnormal or concurrent runs. + + Must be called while holding self._lock. + """ + overflow = len(self._pending_warning_touch_order) - self._max_pending_warning_keys + if overflow <= 0: + return + + candidates = [key for key in self._pending_warning_touch_order if key != protected_key] + for key in candidates[:overflow]: + self._drop_pending_warning_key_locked(key) + + def _queue_pending_warning(self, runtime: Runtime, warning: str) -> None: + """Queue one transient warning for the current thread/run with caps.""" + pending_key = self._pending_key(runtime) + with self._lock: + warnings = self._pending_warnings[pending_key] + if warning not in warnings: + warnings.append(warning) + if len(warnings) > _MAX_PENDING_WARNINGS_PER_RUN: + del warnings[: len(warnings) - _MAX_PENDING_WARNINGS_PER_RUN] + self._touch_pending_warning_key_locked(pending_key) + self._prune_pending_warning_state_locked(protected_key=pending_key) + def _track_and_check(self, state: AgentState, runtime: Runtime) -> tuple[str | None, bool]: """Track tool calls and check for loops. @@ -268,6 +359,12 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): if len(history) > self.window_size: history[:] = history[-self.window_size :] + warned_hashes = self._warned.get(thread_id) + if warned_hashes is not None: + warned_hashes.intersection_update(history) + if not warned_hashes: + self._warned.pop(thread_id, None) + count = history.count(call_hash) tool_names = [tc.get("name", "?") for tc in tool_calls] @@ -381,7 +478,10 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): warning, hard_stop = self._track_and_check(state, runtime) if hard_stop: - # Strip tool_calls from the last AIMessage to force text output + # Strip tool_calls from the last AIMessage to force text output. + # Once tool_calls are stripped, the AIMessage no longer requires + # matching ToolMessage responses, so mutating it in place here + # is safe for OpenAI/Moonshot pairing validators. messages = state.get("messages", []) last_msg = messages[-1] content = self._append_text(last_msg.content, warning or _HARD_STOP_MSG) @@ -389,33 +489,48 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): return {"messages": [stripped_msg]} if warning: - # WORKAROUND for v2.0-m1 — see #2724. - # - # Append the warning to the AIMessage content instead of - # injecting a separate HumanMessage. Inserting any non-tool - # message between an AIMessage(tool_calls=...) and its - # ToolMessage responses breaks OpenAI/Moonshot strict pairing - # validation ("tool_call_ids did not have response messages") - # because the tools node has not run yet at after_model time. - # tool_calls are preserved so the tools node still executes. - # - # This is a temporary mitigation: mutating an existing - # AIMessage to carry framework-authored text leaks loop-warning - # text into downstream consumers (MemoryMiddleware fact - # extraction, TitleMiddleware, telemetry, model replay) as if - # the model said it. The proper fix is to defer warning - # injection from after_model to wrap_model_call so every prior - # ToolMessage is already in the request — see RFC #2517 (which - # lists "loop intervention does not leave invalid - # tool-call/tool-message state" as acceptance criteria) and - # the prototype on `fix/loop-detection-tool-call-pairing`. - messages = state.get("messages", []) - last_msg = messages[-1] - patched_msg = last_msg.model_copy(update={"content": self._append_text(last_msg.content, warning)}) - return {"messages": [patched_msg]} + # Defer injection to the next model call. We must NOT alter the + # AIMessage(tool_calls=...) here (would put framework words in + # the model's mouth, polluting downstream consumers like + # MemoryMiddleware), nor insert a separate non-tool message + # (would break OpenAI/Moonshot tool-call pairing because the + # tools node has not produced ToolMessage responses yet). The + # warning is delivered via ``wrap_model_call`` below. + self._queue_pending_warning(runtime, warning) + return None return None + def _clear_other_run_pending_warnings(self, runtime: Runtime) -> None: + """Drop stale pending warnings for previous runs in this thread.""" + thread_id, current_run_id = self._pending_key(runtime) + with self._lock: + for key in list(self._pending_warnings): + if key[0] == thread_id and key[1] != current_run_id: + self._drop_pending_warning_key_locked(key) + + def _clear_current_run_pending_warnings(self, runtime: Runtime) -> None: + """Drop pending warnings owned by the current thread/run.""" + pending_key = self._pending_key(runtime) + with self._lock: + self._drop_pending_warning_key_locked(pending_key) + + @staticmethod + def _format_warning_message(warnings: list[str]) -> str: + """Merge pending warnings into one prompt message.""" + deduped = list(dict.fromkeys(warnings)) + return "\n\n".join(deduped) + + @override + def before_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_other_run_pending_warnings(runtime) + return None + + @override + async def abefore_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_other_run_pending_warnings(runtime) + return None + @override def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) @@ -424,6 +539,59 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: return self._apply(state, runtime) + @override + def after_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_current_run_pending_warnings(runtime) + return None + + @override + async def aafter_agent(self, state: AgentState, runtime: Runtime) -> dict | None: + self._clear_current_run_pending_warnings(runtime) + return None + + def _drain_pending_warnings(self, runtime: Runtime) -> list[str]: + """Pop and return all queued warnings for *runtime*'s thread/run.""" + pending_key = self._pending_key(runtime) + with self._lock: + warnings = self._pending_warnings.pop(pending_key, []) + self._pending_warning_touch_order.pop(pending_key, None) + return warnings + + def _augment_request(self, request: ModelRequest) -> ModelRequest: + """Append queued loop warnings (if any) to the outgoing message list. + + The warning is placed *after* every existing message, including the + ToolMessage responses to the previous AIMessage(tool_calls). This + keeps ``assistant tool_calls -> tool_messages`` pairing intact for + OpenAI/Moonshot, avoids the Anthropic mid-stream SystemMessage + restriction (we use HumanMessage), and never mutates an existing + AIMessage. + """ + warnings = self._drain_pending_warnings(request.runtime) + if not warnings: + return request + new_messages = [ + *request.messages, + HumanMessage(content=self._format_warning_message(warnings), name="loop_warning"), + ] + return request.override(messages=new_messages) + + @override + def wrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse], + ) -> ModelCallResult: + return handler(self._augment_request(request)) + + @override + async def awrap_model_call( + self, + request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]], + ) -> ModelCallResult: + return await handler(self._augment_request(request)) + def reset(self, thread_id: str | None = None) -> None: """Clear tracking state. If thread_id given, clear only that thread.""" with self._lock: @@ -432,8 +600,13 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]): self._warned.pop(thread_id, None) self._tool_freq.pop(thread_id, None) self._tool_freq_warned.pop(thread_id, None) + for key in list(self._pending_warnings): + if key[0] == thread_id: + self._drop_pending_warning_key_locked(key) else: self._history.clear() self._warned.clear() self._tool_freq.clear() self._tool_freq_warned.clear() + self._pending_warnings.clear() + self._pending_warning_touch_order.clear() diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index f85062a17..61a402def 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -372,37 +372,6 @@ class TestExtractResponseText: # Should return "" (no text in current turn), NOT "Hi there!" from previous turn assert _extract_response_text(result) == "" - def test_does_not_publish_loop_warning_on_tool_calling_ai_message(self): - """Loop-detection warning text on a tool-calling AI message is middleware-authored.""" - from app.channels.manager import _extract_response_text - - result = { - "messages": [ - {"type": "human", "content": "search the repo"}, - { - "type": "ai", - "content": "[LOOP DETECTED] You are repeating the same tool calls.", - "tool_calls": [{"name": "grep", "args": {"pattern": "TODO"}, "id": "call_1"}], - }, - ] - } - assert _extract_response_text(result) == "" - - def test_preserves_visible_text_when_stripping_loop_warning(self): - from app.channels.manager import _extract_response_text - - result = { - "messages": [ - {"type": "human", "content": "prepare the report"}, - { - "type": "ai", - "content": "Here is the report.\n\n[LOOP DETECTED] You are repeating the same tool calls.", - "tool_calls": [{"name": "present_files", "args": {"filepaths": ["/mnt/user-data/outputs/report.md"]}, "id": "call_1"}], - }, - ] - } - assert _extract_response_text(result) == "Here is the report." - # --------------------------------------------------------------------------- # ChannelManager tests diff --git a/backend/tests/test_loop_detection_middleware.py b/backend/tests/test_loop_detection_middleware.py index 022afc117..3b7256ad3 100644 --- a/backend/tests/test_loop_detection_middleware.py +++ b/backend/tests/test_loop_detection_middleware.py @@ -1,24 +1,94 @@ """Tests for LoopDetectionMiddleware.""" import copy +from collections import OrderedDict +from typing import Any from unittest.mock import MagicMock -from langchain_core.messages import AIMessage, SystemMessage +import pytest +from langchain.agents import create_agent +from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.runnables import Runnable +from langchain_core.tools import tool as as_tool +from pydantic import PrivateAttr from deerflow.agents.middlewares.loop_detection_middleware import ( _HARD_STOP_MSG, + _MAX_PENDING_WARNINGS_PER_RUN, LoopDetectionMiddleware, _hash_tool_calls, ) -def _make_runtime(thread_id="test-thread"): +def _make_runtime(thread_id="test-thread", run_id="test-run"): """Build a minimal Runtime mock with context.""" runtime = MagicMock() - runtime.context = {"thread_id": thread_id} + runtime.context = {"thread_id": thread_id, "run_id": run_id} return runtime +def _pending_key(thread_id="test-thread", run_id="test-run"): + return (thread_id, run_id) + + +def _make_request(messages, runtime): + """Build a minimal ModelRequest stand-in for wrap_model_call tests.""" + request = MagicMock() + request.messages = list(messages) + request.runtime = runtime + request.override = lambda **updates: _override_request(request, updates) + return request + + +def _override_request(request, updates): + """Mimic ModelRequest.override(): return a copy with fields replaced.""" + new = MagicMock() + new.messages = updates.get("messages", request.messages) + new.runtime = updates.get("runtime", request.runtime) + new.override = lambda **u: _override_request(new, u) + return new + + +def _capture_handler(): + """Build a sync handler that records the request it was called with.""" + captured: list = [] + + def handler(req): + captured.append(req) + return MagicMock() + + return captured, handler + + +class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel): + """Fake chat model that records each model request's messages.""" + + _seen_messages: list[list[Any]] = PrivateAttr(default_factory=list) + + @property + def seen_messages(self) -> list[list[Any]]: + return self._seen_messages + + def bind_tools( + self, + tools: Any, + *, + tool_choice: Any = None, + **kwargs: Any, + ) -> Runnable: + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self._seen_messages.append(list(messages)) + return super()._generate( + messages, + stop=stop, + run_manager=run_manager, + **kwargs, + ) + + def _make_state(tool_calls=None, content=""): """Build a minimal AgentState dict with an AIMessage. @@ -138,7 +208,15 @@ class TestLoopDetection: result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None - def test_warn_at_threshold(self): + def test_warn_at_threshold_queues_but_does_not_mutate_state(self): + """At warn threshold, ``after_model`` enqueues but returns None. + + Detection observes the just-emitted AIMessage(tool_calls=...). The + tools node hasn't run yet, so injecting any non-tool message here + would split the assistant's tool_calls from their ToolMessage + responses and break OpenAI/Moonshot pairing. The warning is + delivered later from ``wrap_model_call``. + """ mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=5) runtime = _make_runtime() call = [_bash_call("ls")] @@ -146,44 +224,150 @@ class TestLoopDetection: for _ in range(2): mw._apply(_make_state(tool_calls=call), runtime) - # Third identical call triggers warning. The warning is appended to - # the AIMessage content (tool_calls preserved) — never inserted as a - # separate HumanMessage between the AIMessage(tool_calls) and its - # ToolMessage responses, which would break OpenAI/Moonshot strict - # tool-call pairing validation. + # Third identical call triggers warning detection. result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - msgs = result["messages"] - assert len(msgs) == 1 - assert isinstance(msgs[0], AIMessage) - assert len(msgs[0].tool_calls) == len(call) - assert msgs[0].tool_calls[0]["id"] == call[0]["id"] - assert "LOOP DETECTED" in msgs[0].content + # Detection must not mutate state — the AIMessage with tool_calls is + # left untouched so the tools node runs normally. + assert result is None + # ...but a warning is queued for the next model call. + assert mw._pending_warnings[_pending_key()] + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key()][0] - def test_warn_does_not_break_tool_call_pairing(self): - """Regression: the warn branch must NOT inject a non-tool message - after an AIMessage(tool_calls=...). Moonshot/OpenAI reject the next - request with 'tool_call_ids did not have response messages' if any - non-tool message is wedged between the AIMessage and its ToolMessage - responses. See #2029. + def test_warn_injected_at_next_model_call(self): + """``wrap_model_call`` appends a HumanMessage(loop_warning) to the + outgoing messages — *after* every existing message — so that the + AIMessage(tool_calls=...) -> ToolMessage(...) pairing stays intact. """ mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] - - for _ in range(2): + for _ in range(3): mw._apply(_make_state(tool_calls=call), runtime) - result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - msgs = result["messages"] - assert len(msgs) == 1 - assert isinstance(msgs[0], AIMessage) - assert len(msgs[0].tool_calls) == len(call) - assert msgs[0].tool_calls[0]["id"] == call[0]["id"] + # Build the messages the agent runtime would assemble for the next + # turn: prior AIMessage(tool_calls), its ToolMessage responses, ... + ai_msg = AIMessage(content="", tool_calls=call) + tool_msg = ToolMessage(content="ok", tool_call_id=call[0]["id"], name="bash") + request = _make_request([ai_msg, tool_msg], runtime) - def test_warn_only_injected_once(self): - """Warning for the same hash should only be injected once per thread.""" + captured, handler = _capture_handler() + mw.wrap_model_call(request, handler) + + sent = captured[0].messages + # AIMessage and ToolMessage stay in order, untouched. + assert sent[0] is ai_msg + assert sent[1] is tool_msg + # HumanMessage(warning) appears AFTER the ToolMessage — pairing intact. + assert isinstance(sent[2], HumanMessage) + assert sent[2].name == "loop_warning" + assert "LOOP DETECTED" in sent[2].content + + def test_warn_queue_drained_after_injection(self): + """A queued warning must be emitted exactly once per detection event.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime = _make_runtime() + call = [_bash_call("ls")] + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime) + + request = _make_request([AIMessage(content="hi")], runtime) + captured, handler = _capture_handler() + + # First call: warning is appended. + mw.wrap_model_call(request, handler) + first = captured[0].messages + assert any(isinstance(m, HumanMessage) for m in first) + + # Subsequent call without new detection: no warning re-emitted. + request2 = _make_request([AIMessage(content="hi")], runtime) + mw.wrap_model_call(request2, handler) + second = captured[1].messages + assert not any(isinstance(m, HumanMessage) for m in second) + + def test_warn_queue_scoped_by_run_id(self): + """A warning queued for one run must not be injected into another run.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime_a = _make_runtime(run_id="run-A") + runtime_b = _make_runtime(run_id="run-B") + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime_a) + + request_b = _make_request([AIMessage(content="hi")], runtime_b) + captured, handler = _capture_handler() + mw.wrap_model_call(request_b, handler) + assert not any(isinstance(m, HumanMessage) for m in captured[0].messages) + assert mw._pending_warnings.get(_pending_key(run_id="run-A")) + + request_a = _make_request([AIMessage(content="hi")], runtime_a) + mw.wrap_model_call(request_a, handler) + assert any(isinstance(message, HumanMessage) and message.name == "loop_warning" for message in captured[1].messages) + + def test_missing_run_id_uses_default_pending_scope(self): + """When runtime has no run_id, warning handling falls back to the default run scope.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime = MagicMock() + runtime.context = {"thread_id": "test-thread"} + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime) + + assert mw._pending_warnings.get(_pending_key(run_id="default")) + + request = _make_request([AIMessage(content="hi")], runtime) + captured, handler = _capture_handler() + mw.wrap_model_call(request, handler) + + loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] + assert len(loop_warnings) == 1 + assert "LOOP DETECTED" in loop_warnings[0].content + assert not mw._pending_warnings.get(_pending_key(run_id="default")) + + def test_before_agent_clears_stale_pending_warnings_for_thread(self): + """Starting a new run drops stale warnings from prior runs in the same thread.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime_a = _make_runtime(run_id="run-A") + runtime_b = _make_runtime(run_id="run-B") + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime_a) + + assert mw._pending_warnings.get(_pending_key(run_id="run-A")) + mw.before_agent({"messages": []}, runtime_b) + assert not mw._pending_warnings.get(_pending_key(run_id="run-A")) + + def test_after_agent_clears_current_run_pending_warnings(self): + """Run cleanup should drop warnings that never reached wrap_model_call.""" + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + runtime = _make_runtime() + call = [_bash_call("ls")] + + for _ in range(3): + mw._apply(_make_state(tool_calls=call), runtime) + + assert mw._pending_warnings.get(_pending_key()) + mw.after_agent({"messages": []}, runtime) + assert not mw._pending_warnings.get(_pending_key()) + + def test_multiple_pending_warnings_are_merged_into_one_message(self): + """Edge-case drains should produce one loop_warning prompt message.""" + mw = LoopDetectionMiddleware() + runtime = _make_runtime() + mw._pending_warnings[_pending_key()] = ["first warning", "second warning", "first warning"] + request = _make_request([AIMessage(content="hi")], runtime) + captured, handler = _capture_handler() + + mw.wrap_model_call(request, handler) + + loop_warnings = [message for message in captured[0].messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] + assert len(loop_warnings) == 1 + assert loop_warnings[0].content == "first warning\n\nsecond warning" + + def test_warn_only_queued_once_per_hash(self): + """Same hash repeated past the threshold should warn only once.""" mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) runtime = _make_runtime() call = [_bash_call("ls")] @@ -192,14 +376,13 @@ class TestLoopDetection: for _ in range(2): mw._apply(_make_state(tool_calls=call), runtime) - # Third — warning injected - result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # Third — warning queued + mw._apply(_make_state(tool_calls=call), runtime) + assert len(mw._pending_warnings[_pending_key()]) == 1 - # Fourth — warning already injected, should return None - result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is None + # Fourth — already warned for this hash, no additional enqueue. + mw._apply(_make_state(tool_calls=call), runtime) + assert len(mw._pending_warnings[_pending_key()]) == 1 def test_hard_stop_at_limit(self): mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4) @@ -257,6 +440,7 @@ class TestLoopDetection: mw.reset() result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None + assert not mw._pending_warnings.get(_pending_key()) def test_non_ai_message_ignored(self): mw = LoopDetectionMiddleware() @@ -283,15 +467,16 @@ class TestLoopDetection: # One call on thread B mw._apply(_make_state(tool_calls=call), runtime_b) - # Second call on thread A — triggers warning (2 >= warn_threshold) - result = mw._apply(_make_state(tool_calls=call), runtime_a) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # Second call on thread A — queues warning under thread-A only. + mw._apply(_make_state(tool_calls=call), runtime_a) + assert mw._pending_warnings.get(_pending_key("thread-A")) + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0] + assert not mw._pending_warnings.get(_pending_key("thread-B")) - # Second call on thread B — also triggers (independent tracking) - result = mw._apply(_make_state(tool_calls=call), runtime_b) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # Second call on thread B — independent queue. + mw._apply(_make_state(tool_calls=call), runtime_b) + assert mw._pending_warnings.get(_pending_key("thread-B")) + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0] def test_lru_eviction(self): """Old threads should be evicted when max_tracked_threads is exceeded.""" @@ -313,6 +498,55 @@ class TestLoopDetection: assert "thread-new" in mw._history assert len(mw._history) == 3 + def test_warned_hashes_are_pruned_to_sliding_window(self): + """A long-lived thread should not keep every historical warned hash.""" + mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=100, window_size=4) + runtime = _make_runtime() + + for i in range(12): + call = [_bash_call(f"cmd_{i}")] + mw._apply(_make_state(tool_calls=call), runtime) + mw._apply(_make_state(tool_calls=call), runtime) + + assert len(mw._history["test-thread"]) <= 4 + assert set(mw._warned["test-thread"]).issubset(set(mw._history["test-thread"])) + assert len(mw._warned["test-thread"]) <= 4 + + def test_pending_warning_keys_are_capped(self): + """Abnormal same-thread runs cannot grow pending-warning keys forever.""" + mw = LoopDetectionMiddleware(warn_threshold=2, max_tracked_threads=2) + + for i in range(10): + runtime = _make_runtime(thread_id="same-thread", run_id=f"run-{i}") + mw._queue_pending_warning(runtime, f"warning-{i}") + + assert len(mw._pending_warnings) == mw._max_pending_warning_keys + assert len(mw._pending_warning_touch_order) == mw._max_pending_warning_keys + assert _pending_key("same-thread", "run-9") in mw._pending_warnings + + def test_pending_warning_list_is_capped_and_deduped(self): + """One run cannot accumulate an unbounded warning list.""" + mw = LoopDetectionMiddleware() + runtime = _make_runtime() + + for i in range(_MAX_PENDING_WARNINGS_PER_RUN + 4): + mw._queue_pending_warning(runtime, f"warning-{i}") + mw._queue_pending_warning(runtime, f"warning-{_MAX_PENDING_WARNINGS_PER_RUN + 3}") + + warnings = mw._pending_warnings[_pending_key()] + assert len(warnings) == _MAX_PENDING_WARNINGS_PER_RUN + assert warnings == [f"warning-{i}" for i in range(4, _MAX_PENDING_WARNINGS_PER_RUN + 4)] + + def test_pending_warning_touch_order_cleared_with_pending_key(self): + mw = LoopDetectionMiddleware() + runtime = _make_runtime() + mw._queue_pending_warning(runtime, "warning") + + mw.after_agent({"messages": []}, runtime) + + assert mw._pending_warnings == {} + assert mw._pending_warning_touch_order == OrderedDict() + def test_thread_safe_mutations(self): """Verify lock is used for mutations (basic structural test).""" mw = LoopDetectionMiddleware() @@ -331,6 +565,99 @@ class TestLoopDetection: assert "default" in mw._history +class TestLoopDetectionAgentGraphIntegration: + def test_loop_warning_is_transient_in_real_agent_graph(self): + """after_model queues the warning; wrap_model_call injects it request-only.""" + + @as_tool + def bash(command: str) -> str: + """Run a fake shell command.""" + return f"ran: {command}" + + repeated_calls = [[{"name": "bash", "id": f"call_ls_{i}", "args": {"command": "ls"}}] for i in range(3)] + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + model = _CapturingFakeMessagesListChatModel( + responses=[ + AIMessage(content="", tool_calls=repeated_calls[0]), + AIMessage(content="", tool_calls=repeated_calls[1]), + AIMessage(content="", tool_calls=repeated_calls[2]), + AIMessage(content="final answer"), + ], + ) + graph = create_agent(model=model, tools=[bash], middleware=[mw]) + + result = graph.invoke( + {"messages": [("user", "inspect the directory")]}, + context={"thread_id": "integration-thread", "run_id": "integration-run"}, + config={"recursion_limit": 20}, + ) + + assert len(model.seen_messages) == 4 + loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages] + assert loop_warnings_by_call[0] == [] + assert loop_warnings_by_call[1] == [] + assert loop_warnings_by_call[2] == [] + assert len(loop_warnings_by_call[3]) == 1 + assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content + + fourth_request = model.seen_messages[3] + assert isinstance(fourth_request[-2], ToolMessage) + assert fourth_request[-2].tool_call_id == "call_ls_2" + assert fourth_request[-1] is loop_warnings_by_call[3][0] + + persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"] + assert persisted_loop_warnings == [] + assert result["messages"][-1].content == "final answer" + assert mw._pending_warnings == {} + assert mw._pending_warning_touch_order == OrderedDict() + + @pytest.mark.asyncio + async def test_loop_warning_is_transient_in_async_agent_graph(self): + """awrap_model_call injects loop_warning request-only in async graph runs.""" + + @as_tool + async def bash(command: str) -> str: + """Run a fake shell command.""" + return f"ran: {command}" + + repeated_calls = [[{"name": "bash", "id": f"call_async_ls_{i}", "args": {"command": "ls"}}] for i in range(3)] + mw = LoopDetectionMiddleware(warn_threshold=3, hard_limit=10) + model = _CapturingFakeMessagesListChatModel( + responses=[ + AIMessage(content="", tool_calls=repeated_calls[0]), + AIMessage(content="", tool_calls=repeated_calls[1]), + AIMessage(content="", tool_calls=repeated_calls[2]), + AIMessage(content="async final answer"), + ], + ) + graph = create_agent(model=model, tools=[bash], middleware=[mw]) + + result = await graph.ainvoke( + {"messages": [("user", "inspect the directory asynchronously")]}, + context={"thread_id": "async-integration-thread", "run_id": "async-integration-run"}, + config={"recursion_limit": 20}, + ) + + assert len(model.seen_messages) == 4 + loop_warnings_by_call = [[message for message in messages if isinstance(message, HumanMessage) and message.name == "loop_warning"] for messages in model.seen_messages] + assert loop_warnings_by_call[0] == [] + assert loop_warnings_by_call[1] == [] + assert loop_warnings_by_call[2] == [] + assert len(loop_warnings_by_call[3]) == 1 + assert "LOOP DETECTED" in loop_warnings_by_call[3][0].content + + fourth_request = model.seen_messages[3] + assert isinstance(fourth_request[-2], ToolMessage) + assert fourth_request[-2].tool_call_id == "call_async_ls_2" + assert fourth_request[-1] is loop_warnings_by_call[3][0] + + persisted_loop_warnings = [message for message in result["messages"] if isinstance(message, HumanMessage) and message.name == "loop_warning"] + assert persisted_loop_warnings == [] + assert result["messages"][-1].content == "async final answer" + assert mw._pending_warnings == {} + assert mw._pending_warning_touch_order == OrderedDict() + + class TestAppendText: """Unit tests for LoopDetectionMiddleware._append_text.""" @@ -507,33 +834,29 @@ class TestToolFrequencyDetection: for i in range(4): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) - # 5th call to read_file (different file each time) triggers freq warning + # 5th call queues a per-tool-type frequency warning; state untouched. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_4.py")]), runtime) - assert result is not None - msg = result["messages"][0] - # Warning is appended to the AIMessage content; tool_calls preserved - # so the tools node still runs and Moonshot/OpenAI tool-call pairing - # validation does not break. - assert isinstance(msg, AIMessage) - assert msg.tool_calls - assert "read_file" in msg.content - assert "LOOP DETECTED" in msg.content + assert result is None + queued = mw._pending_warnings.get(_pending_key(), []) + assert queued + assert "read_file" in queued[0] + assert "LOOP DETECTED" in queued[0] - def test_freq_warn_only_injected_once(self): + def test_freq_warn_only_queued_once(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) runtime = _make_runtime() for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) - # 3rd triggers warning - result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + # 3rd queues a frequency warning. + mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) + assert len(mw._pending_warnings[_pending_key()]) == 1 - # 4th should not re-warn (already warned for read_file) + # 4th: same tool name, no additional enqueue. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_3.py")]), runtime) assert result is None + assert len(mw._pending_warnings[_pending_key()]) == 1 def test_freq_hard_stop_at_limit(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=6) @@ -565,10 +888,10 @@ class TestToolFrequencyDetection: result = mw._apply(_make_state(tool_calls=[_bash_call(f"cmd_{i}")]), runtime) assert result is None - # 3rd read_file triggers (read_file count = 3) + # 3rd read_file triggers — warning is queued (state unchanged). result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) - assert result is not None - assert "read_file" in result["messages"][0].content + assert result is None + assert "read_file" in mw._pending_warnings[_pending_key()][0] def test_freq_reset_clears_state(self): mw = LoopDetectionMiddleware(tool_freq_warn=3, tool_freq_hard_limit=10) @@ -600,10 +923,10 @@ class TestToolFrequencyDetection: assert "thread-A" not in mw._tool_freq assert "thread-A" not in mw._tool_freq_warned - # thread-B state should still be intact — 3rd call triggers warn + # thread-B state should still be intact — 3rd call queues a warn. result = mw._apply(_make_state(tool_calls=[self._read_call("/b_2.py")]), runtime_b) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + assert result is None + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-B")][0] # thread-A restarted from 0 — should not trigger result = mw._apply(_make_state(tool_calls=[self._read_call("/a_new.py")]), runtime_a) @@ -623,10 +946,11 @@ class TestToolFrequencyDetection: for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/other_{i}.py")]), runtime_b) - # 3rd call on thread A — triggers (count=3 for thread A only) + # 3rd call on thread A — queues a warning (count=3 for thread A only). result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime_a) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + assert result is None + assert "LOOP DETECTED" in mw._pending_warnings[_pending_key("thread-A")][0] + assert not mw._pending_warnings.get(_pending_key("thread-B")) def test_multi_tool_single_response_counted(self): """When a single response has multiple tool calls, each is counted.""" @@ -643,10 +967,10 @@ class TestToolFrequencyDetection: result = mw._apply(_make_state(tool_calls=call), runtime) assert result is None - # Response 3: 1 more → count = 5 → triggers warn + # Response 3: 1 more → count = 5 → queues warn. result = mw._apply(_make_state(tool_calls=[self._read_call("/e.py")]), runtime) - assert result is not None - assert "read_file" in result["messages"][0].content + assert result is None + assert "read_file" in mw._pending_warnings[_pending_key()][0] def test_override_tool_uses_override_thresholds(self): """A tool in tool_freq_overrides uses its own thresholds, not the global ones.""" @@ -674,10 +998,14 @@ class TestToolFrequencyDetection: for i in range(2): mw._apply(_make_state(tool_calls=[self._read_call(f"/file_{i}.py")]), runtime) - # 3rd read_file call hits global warn=3 (read_file has no override) + # 3rd read_file call hits global warn=3 (read_file has no override). + # Warning delivery is deferred to wrap_model_call so the just-emitted + # AIMessage(tool_calls=...) is not mutated before ToolMessages exist. result = mw._apply(_make_state(tool_calls=[self._read_call("/file_2.py")]), runtime) - assert result is not None - assert "read_file" in result["messages"][0].content + assert result is None + queued = mw._pending_warnings.get(_pending_key(), []) + assert queued + assert "read_file" in queued[0] def test_hash_detection_takes_priority(self): """Hash-based hard stop fires before frequency check for identical calls.""" @@ -736,11 +1064,13 @@ class TestFromConfig: mw = LoopDetectionMiddleware.from_config(self._config()) assert mw._tool_freq_overrides == {} - def test_constructed_middleware_detects_loops(self): + def test_constructed_middleware_queues_loop_warning(self): mw = LoopDetectionMiddleware.from_config(self._config(warn_threshold=2, hard_limit=4)) runtime = _make_runtime() call = [_bash_call("ls")] mw._apply(_make_state(tool_calls=call), runtime) result = mw._apply(_make_state(tool_calls=call), runtime) - assert result is not None - assert "LOOP DETECTED" in result["messages"][0].content + assert result is None + queued = mw._pending_warnings.get(_pending_key(), []) + assert queued + assert "LOOP DETECTED" in queued[0] diff --git a/frontend/src/content/en/harness/middlewares.mdx b/frontend/src/content/en/harness/middlewares.mdx index 72195664d..389b9881c 100644 --- a/frontend/src/content/en/harness/middlewares.mdx +++ b/frontend/src/content/en/harness/middlewares.mdx @@ -50,6 +50,8 @@ Intercepts clarification tool calls and converts them into proper user-facing re Detects when the agent is making the same tool call repeatedly without making progress. When a loop is detected, the middleware intervenes to break the cycle and prevents the agent from burning turns indefinitely. +Warning interventions are queued per thread and run, then drained on the next model call as a single hidden `HumanMessage(name="loop_warning")` appended after existing tool results. This keeps provider tool-call pairing valid. Run start/end hooks clear stale or undelivered warnings, and hard stops still strip tool calls before forcing a final text response. + **Configuration**: built-in, no user configuration. --- diff --git a/frontend/src/content/zh/harness/middlewares.mdx b/frontend/src/content/zh/harness/middlewares.mdx index 051729b94..9e81caa3e 100644 --- a/frontend/src/content/zh/harness/middlewares.mdx +++ b/frontend/src/content/zh/harness/middlewares.mdx @@ -50,6 +50,8 @@ import { Callout } from "nextra/components"; 检测 Agent 是否在没有取得进展的情况下重复进行相同的工具调用。检测到循环时,中间件会介入打破循环,防止 Agent 无限消耗轮次。 +Warning 介入会按 thread 和 run 排队,并在下一次模型调用时合并为一条隐藏的 `HumanMessage(name="loop_warning")`,追加到已有工具结果之后。这样不会破坏 provider 对 tool-call/tool-message 配对的校验。Run 开始和结束时会清理过期或未送达的 warning;达到 hard stop 时仍会清空 tool calls 并强制生成最终文本回复。 + **配置**:内置,无需用户配置。 --- From 8b697245ebe835bedc4e386dbfb15c422e2f6e11 Mon Sep 17 00:00:00 2001 From: AochenShen99 <142667174+ShenAC-SAC@users.noreply.github.com> Date: Thu, 21 May 2026 14:44:34 +0800 Subject: [PATCH 56/86] fix(sandbox): avoid blocking sandbox readiness polling (#2822) * fix(sandbox): offload async sandbox acquisition Run blocking sandbox provider acquisition through the async provider hook so eager sandbox setup does not stall the event loop. * fix(sandbox): add async readiness polling Introduce an async sandbox readiness poller using httpx and asyncio.sleep while preserving the existing synchronous API. * test(sandbox): cover async readiness polling Lock in non-blocking readiness behavior so the async helper does not regress to requests.get or time.sleep. * fix(sandbox): allow anonymous backend creation * fix(sandbox): use async readiness in provider acquisition * fix(sandbox): use async acquisition for lazy tools * test(sandbox): cover anonymous remote creation * fix(sandbox): clamp async readiness timeout budget * fix(sandbox): offload async lock file handling * fix(sandbox): delegate async middleware fallthrough * docs(sandbox): document async acquisition path * fix(sandbox): offload async sandbox release * docs(sandbox): mention async release hook * fix(sandbox): address async lock review Reduce duplicate sync/async sandbox acquisition state handling and move async thread-lock waits onto a dedicated executor with cancellation-safe cleanup. * chore: retrigger ci Retrigger GitHub Actions after upstream main fixed the stale PR merge lint failure. * test(sandbox): sync backend unit fixtures --------- Co-authored-by: Willem Jiang --- backend/CLAUDE.md | 2 +- backend/README.md | 2 +- .../aio_sandbox/aio_sandbox_provider.py | 297 ++++++++++++++---- .../deerflow/community/aio_sandbox/backend.py | 32 +- .../community/aio_sandbox/local_backend.py | 2 +- .../community/aio_sandbox/remote_backend.py | 4 +- .../harness/deerflow/sandbox/middleware.py | 45 +++ .../deerflow/sandbox/sandbox_provider.py | 11 + .../harness/deerflow/sandbox/tools.py | 174 ++++++++++ backend/tests/test_aio_sandbox_provider.py | 177 +++++++++++ backend/tests/test_aio_sandbox_readiness.py | 119 +++++++ backend/tests/test_remote_sandbox_backend.py | 20 ++ backend/tests/test_sandbox_middleware.py | 225 +++++++++++++ 13 files changed, 1037 insertions(+), 73 deletions(-) create mode 100644 backend/tests/test_aio_sandbox_readiness.py create mode 100644 backend/tests/test_sandbox_middleware.py diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index b951f919c..886b82dcb 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -236,7 +236,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti ### Sandbox System (`packages/harness/deerflow/sandbox/`) **Interface**: Abstract `Sandbox` with `execute_command`, `read_file`, `write_file`, `list_dir` -**Provider Pattern**: `SandboxProvider` with `acquire`, `get`, `release` lifecycle +**Provider Pattern**: `SandboxProvider` with `acquire`, `acquire_async`, `get`, `release` lifecycle. Async agent/tool paths call async sandbox lifecycle hooks so Docker sandbox creation, discovery, cross-process locking, readiness polling, and release stay off the event loop. **Implementations**: - `LocalSandboxProvider` - Local filesystem execution. `acquire(thread_id)` returns a per-thread `LocalSandbox` (id `local:{thread_id}`) whose `path_mappings` resolve `/mnt/user-data/{workspace,uploads,outputs}` and `/mnt/acp-workspace` to that thread's host directories, so the public `Sandbox` API honours the `/mnt/user-data` contract uniformly with AIO. `acquire()` / `acquire(None)` keeps the legacy generic singleton (id `local`) for callers without a thread context. Per-thread sandboxes are held in an LRU cache (default 256 entries) guarded by a `threading.Lock`. - `AioSandboxProvider` (`packages/harness/deerflow/community/`) - Docker-based isolation diff --git a/backend/README.md b/backend/README.md index 8c61e2db2..0ee0d454b 100644 --- a/backend/README.md +++ b/backend/README.md @@ -69,7 +69,7 @@ Middlewares execute in strict order, each handling a specific concern: Per-thread isolated execution with virtual path translation: - **Abstract interface**: `execute_command`, `read_file`, `write_file`, `list_dir` -- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/) +- **Providers**: `LocalSandboxProvider` (filesystem) and `AioSandboxProvider` (Docker, in community/). Async runtime paths use async sandbox lifecycle hooks so startup, readiness polling, and release do not block the event loop. - **Virtual paths**: `/mnt/user-data/{workspace,uploads,outputs}` → thread-specific physical directories - **Skills path**: `/mnt/skills` → `deer-flow/skills/` directory - **Skills loading**: Recursively discovers nested `SKILL.md` files under `skills/{public,custom}` and preserves nested container paths diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py index 292a43758..4d7e16cab 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox_provider.py @@ -10,6 +10,7 @@ The provider itself handles: - Mount computation (thread-specific, skills) """ +import asyncio import atexit import hashlib import logging @@ -18,6 +19,7 @@ import signal import threading import time import uuid +from concurrent.futures import ThreadPoolExecutor try: import fcntl @@ -32,7 +34,7 @@ from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox_provider import SandboxProvider from .aio_sandbox import AioSandbox -from .backend import SandboxBackend, wait_for_sandbox_ready +from .backend import SandboxBackend, wait_for_sandbox_ready, wait_for_sandbox_ready_async from .local_backend import LocalContainerBackend from .remote_backend import RemoteSandboxBackend from .sandbox_info import SandboxInfo @@ -46,6 +48,9 @@ DEFAULT_CONTAINER_PREFIX = "deer-flow-sandbox" DEFAULT_IDLE_TIMEOUT = 600 # 10 minutes in seconds DEFAULT_REPLICAS = 3 # Maximum concurrent sandbox containers IDLE_CHECK_INTERVAL = 60 # Check every 60 seconds +THREAD_LOCK_EXECUTOR_WORKERS = min(32, (os.cpu_count() or 1) + 4) +_THREAD_LOCK_EXECUTOR = ThreadPoolExecutor(max_workers=THREAD_LOCK_EXECUTOR_WORKERS, thread_name_prefix="sandbox-lock-wait") +atexit.register(_THREAD_LOCK_EXECUTOR.shutdown, wait=False, cancel_futures=True) def _lock_file_exclusive(lock_file) -> None: @@ -66,6 +71,40 @@ def _unlock_file(lock_file) -> None: msvcrt.locking(lock_file.fileno(), msvcrt.LK_UNLCK, 1) +def _open_lock_file(lock_path): + return open(lock_path, "a", encoding="utf-8") + + +async def _acquire_thread_lock_async(lock: threading.Lock) -> None: + """Acquire a threading.Lock without polling or using the default executor.""" + loop = asyncio.get_running_loop() + acquire_future = loop.run_in_executor(_THREAD_LOCK_EXECUTOR, lock.acquire, True) + + try: + acquired = await asyncio.shield(acquire_future) + except asyncio.CancelledError: + acquire_future.add_done_callback(lambda task: _release_cancelled_lock_acquire(lock, task)) + raise + + if not acquired: + raise RuntimeError("Failed to acquire sandbox thread lock") + + +def _release_cancelled_lock_acquire(lock: threading.Lock, task: asyncio.Future[bool]) -> None: + """Release a lock acquired after its awaiting coroutine was cancelled.""" + if task.cancelled(): + return + + try: + acquired = task.result() + except Exception as e: + logger.warning(f"Cancelled sandbox lock acquisition finished with error: {e}") + return + + if acquired: + lock.release() + + class AioSandboxProvider(SandboxProvider): """Sandbox provider that manages containers running the AIO sandbox. @@ -416,6 +455,96 @@ class AioSandboxProvider(SandboxProvider): self._thread_locks[thread_id] = threading.Lock() return self._thread_locks[thread_id] + def _sandbox_id_for_thread(self, thread_id: str | None) -> str: + """Return deterministic IDs for thread sandboxes and random IDs otherwise.""" + return self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8] + + def _reuse_in_process_sandbox(self, thread_id: str | None, *, post_lock: bool = False) -> str | None: + """Reuse an active in-process sandbox for a thread if one is still tracked.""" + if thread_id is None: + return None + + with self._lock: + if thread_id not in self._thread_sandboxes: + return None + + existing_id = self._thread_sandboxes[thread_id] + if existing_id in self._sandboxes: + suffix = " (post-lock check)" if post_lock else "" + logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}{suffix}") + self._last_activity[existing_id] = time.time() + return existing_id + + del self._thread_sandboxes[thread_id] + return None + + def _reclaim_warm_pool_sandbox(self, thread_id: str | None, sandbox_id: str, *, post_lock: bool = False) -> str | None: + """Promote a warm-pool sandbox back to active tracking if available.""" + if thread_id is None: + return None + + with self._lock: + if sandbox_id not in self._warm_pool: + return None + + info, _ = self._warm_pool.pop(sandbox_id) + sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) + self._sandboxes[sandbox_id] = sandbox + self._sandbox_infos[sandbox_id] = info + self._last_activity[sandbox_id] = time.time() + self._thread_sandboxes[thread_id] = sandbox_id + + suffix = " (post-lock check)" if post_lock else f" at {info.sandbox_url}" + logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id}{suffix}") + return sandbox_id + + def _recheck_cached_sandbox(self, thread_id: str, sandbox_id: str) -> str | None: + """Re-check in-memory caches after acquiring the cross-process file lock.""" + return self._reuse_in_process_sandbox(thread_id, post_lock=True) or self._reclaim_warm_pool_sandbox(thread_id, sandbox_id, post_lock=True) + + def _register_discovered_sandbox(self, thread_id: str, info: SandboxInfo) -> str: + """Track a sandbox discovered through the backend.""" + sandbox = AioSandbox(id=info.sandbox_id, base_url=info.sandbox_url) + with self._lock: + self._sandboxes[info.sandbox_id] = sandbox + self._sandbox_infos[info.sandbox_id] = info + self._last_activity[info.sandbox_id] = time.time() + self._thread_sandboxes[thread_id] = info.sandbox_id + + logger.info(f"Discovered existing sandbox {info.sandbox_id} for thread {thread_id} at {info.sandbox_url}") + return info.sandbox_id + + def _register_created_sandbox(self, thread_id: str | None, sandbox_id: str, info: SandboxInfo) -> str: + """Track a newly-created sandbox in the active maps.""" + sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) + with self._lock: + self._sandboxes[sandbox_id] = sandbox + self._sandbox_infos[sandbox_id] = info + self._last_activity[sandbox_id] = time.time() + if thread_id: + self._thread_sandboxes[thread_id] = sandbox_id + + logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") + return sandbox_id + + def _replica_count(self) -> tuple[int, int]: + """Return configured replicas and currently tracked sandbox count.""" + replicas = self._config.get("replicas", DEFAULT_REPLICAS) + with self._lock: + total = len(self._sandboxes) + len(self._warm_pool) + return replicas, total + + def _log_replicas_soft_cap(self, replicas: int, sandbox_id: str, evicted: str | None) -> None: + """Log the result of enforcing the warm-pool replica budget.""" + if evicted: + logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}") + return + + # All slots are occupied by active sandboxes — proceed anyway and log. + # The replicas limit is a soft cap; we never forcibly stop a container + # that is actively serving a thread. + logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit") + # ── Core: acquire / get / release / shutdown ───────────────────────── def acquire(self, thread_id: str | None = None) -> str: @@ -440,6 +569,23 @@ class AioSandboxProvider(SandboxProvider): else: return self._acquire_internal(thread_id) + async def acquire_async(self, thread_id: str | None = None) -> str: + """Acquire a sandbox environment without blocking the event loop. + + Mirrors ``acquire()`` while keeping blocking backend operations off the + event loop and using async-native readiness polling for newly created + sandboxes. + """ + if thread_id: + thread_lock = self._get_thread_lock(thread_id) + await _acquire_thread_lock_async(thread_lock) + try: + return await self._acquire_internal_async(thread_id) + finally: + thread_lock.release() + + return await self._acquire_internal_async(thread_id) + def _acquire_internal(self, thread_id: str | None) -> str: """Internal sandbox acquisition with two-layer consistency. @@ -448,33 +594,17 @@ class AioSandboxProvider(SandboxProvider): sandbox_id is deterministic from thread_id so no shared state file is needed — any process can derive the same container name) """ - # ── Layer 1: In-process cache (fast path) ── - if thread_id: - with self._lock: - if thread_id in self._thread_sandboxes: - existing_id = self._thread_sandboxes[thread_id] - if existing_id in self._sandboxes: - logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id}") - self._last_activity[existing_id] = time.time() - return existing_id - else: - del self._thread_sandboxes[thread_id] + cached_id = self._reuse_in_process_sandbox(thread_id) + if cached_id is not None: + return cached_id # Deterministic ID for thread-specific, random for anonymous - sandbox_id = self._deterministic_sandbox_id(thread_id) if thread_id else str(uuid.uuid4())[:8] + sandbox_id = self._sandbox_id_for_thread(thread_id) # ── Layer 1.5: Warm pool (container still running, no cold-start) ── - if thread_id: - with self._lock: - if sandbox_id in self._warm_pool: - info, _ = self._warm_pool.pop(sandbox_id) - sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) - self._sandboxes[sandbox_id] = sandbox - self._sandbox_infos[sandbox_id] = info - self._last_activity[sandbox_id] = time.time() - self._thread_sandboxes[thread_id] = sandbox_id - logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") - return sandbox_id + reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id) + if reclaimed_id is not None: + return reclaimed_id # ── Layer 2: Backend discovery + create (protected by cross-process lock) ── # Use a file lock so that two processes racing to create the same sandbox @@ -485,6 +615,26 @@ class AioSandboxProvider(SandboxProvider): return self._create_sandbox(thread_id, sandbox_id) + async def _acquire_internal_async(self, thread_id: str | None) -> str: + """Async counterpart to ``_acquire_internal``.""" + cached_id = self._reuse_in_process_sandbox(thread_id) + if cached_id is not None: + return cached_id + + # Deterministic ID for thread-specific, random for anonymous + sandbox_id = self._sandbox_id_for_thread(thread_id) + + # ── Layer 1.5: Warm pool (container still running, no cold-start) ── + reclaimed_id = self._reclaim_warm_pool_sandbox(thread_id, sandbox_id) + if reclaimed_id is not None: + return reclaimed_id + + # ── Layer 2: Backend discovery + create (protected by cross-process lock) ── + if thread_id: + return await self._discover_or_create_with_lock_async(thread_id, sandbox_id) + + return await self._create_sandbox_async(thread_id, sandbox_id) + def _discover_or_create_with_lock(self, thread_id: str, sandbox_id: str) -> str: """Discover an existing sandbox or create a new one under a cross-process file lock. @@ -503,40 +653,50 @@ class AioSandboxProvider(SandboxProvider): locked = True # Re-check in-process caches under the file lock in case another # thread in this process won the race while we were waiting. - with self._lock: - if thread_id in self._thread_sandboxes: - existing_id = self._thread_sandboxes[thread_id] - if existing_id in self._sandboxes: - logger.info(f"Reusing in-process sandbox {existing_id} for thread {thread_id} (post-lock check)") - self._last_activity[existing_id] = time.time() - return existing_id - if sandbox_id in self._warm_pool: - info, _ = self._warm_pool.pop(sandbox_id) - sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) - self._sandboxes[sandbox_id] = sandbox - self._sandbox_infos[sandbox_id] = info - self._last_activity[sandbox_id] = time.time() - self._thread_sandboxes[thread_id] = sandbox_id - logger.info(f"Reclaimed warm-pool sandbox {sandbox_id} for thread {thread_id} (post-lock check)") - return sandbox_id + cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id) + if cached_id is not None: + return cached_id # Backend discovery: another process may have created the container. discovered = self._backend.discover(sandbox_id) if discovered is not None: - sandbox = AioSandbox(id=discovered.sandbox_id, base_url=discovered.sandbox_url) - with self._lock: - self._sandboxes[discovered.sandbox_id] = sandbox - self._sandbox_infos[discovered.sandbox_id] = discovered - self._last_activity[discovered.sandbox_id] = time.time() - self._thread_sandboxes[thread_id] = discovered.sandbox_id - logger.info(f"Discovered existing sandbox {discovered.sandbox_id} for thread {thread_id} at {discovered.sandbox_url}") - return discovered.sandbox_id + return self._register_discovered_sandbox(thread_id, discovered) return self._create_sandbox(thread_id, sandbox_id) finally: if locked: _unlock_file(lock_file) + async def _discover_or_create_with_lock_async(self, thread_id: str, sandbox_id: str) -> str: + """Async counterpart to ``_discover_or_create_with_lock``.""" + paths = get_paths() + user_id = get_effective_user_id() + await asyncio.to_thread(paths.ensure_thread_dirs, thread_id, user_id=user_id) + lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock" + + lock_file = await asyncio.to_thread(_open_lock_file, lock_path) + locked = False + try: + await asyncio.to_thread(_lock_file_exclusive, lock_file) + locked = True + # Re-check in-process caches under the file lock in case another + # thread in this process won the race while we were waiting. + cached_id = self._recheck_cached_sandbox(thread_id, sandbox_id) + if cached_id is not None: + return cached_id + + # Backend discovery is sync because local discovery may inspect + # Docker and perform a health check; keep it off the event loop. + discovered = await asyncio.to_thread(self._backend.discover, sandbox_id) + if discovered is not None: + return self._register_discovered_sandbox(thread_id, discovered) + + return await self._create_sandbox_async(thread_id, sandbox_id) + finally: + if locked: + await asyncio.to_thread(_unlock_file, lock_file) + await asyncio.to_thread(lock_file.close) + def _evict_oldest_warm(self) -> str | None: """Destroy the oldest container in the warm pool to free capacity. @@ -574,18 +734,10 @@ class AioSandboxProvider(SandboxProvider): # Enforce replicas: only warm-pool containers count toward eviction budget. # Active sandboxes are in use by live threads and must not be forcibly stopped. - replicas = self._config.get("replicas", DEFAULT_REPLICAS) - with self._lock: - total = len(self._sandboxes) + len(self._warm_pool) + replicas, total = self._replica_count() if total >= replicas: evicted = self._evict_oldest_warm() - if evicted: - logger.info(f"Evicted warm-pool sandbox {evicted} to stay within replicas={replicas}") - else: - # All slots are occupied by active sandboxes — proceed anyway and log. - # The replicas limit is a soft cap; we never forcibly stop a container - # that is actively serving a thread. - logger.warning(f"All {replicas} replica slots are in active use; creating sandbox {sandbox_id} beyond the soft limit") + self._log_replicas_soft_cap(replicas, sandbox_id, evicted) info = self._backend.create(thread_id, sandbox_id, extra_mounts=extra_mounts or None) @@ -594,16 +746,27 @@ class AioSandboxProvider(SandboxProvider): self._backend.destroy(info) raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}") - sandbox = AioSandbox(id=sandbox_id, base_url=info.sandbox_url) - with self._lock: - self._sandboxes[sandbox_id] = sandbox - self._sandbox_infos[sandbox_id] = info - self._last_activity[sandbox_id] = time.time() - if thread_id: - self._thread_sandboxes[thread_id] = sandbox_id + return self._register_created_sandbox(thread_id, sandbox_id, info) - logger.info(f"Created sandbox {sandbox_id} for thread {thread_id} at {info.sandbox_url}") - return sandbox_id + async def _create_sandbox_async(self, thread_id: str | None, sandbox_id: str) -> str: + """Async counterpart to ``_create_sandbox``.""" + extra_mounts = await asyncio.to_thread(self._get_extra_mounts, thread_id) + + # Enforce replicas: only warm-pool containers count toward eviction budget. + # Active sandboxes are in use by live threads and must not be forcibly stopped. + replicas, total = self._replica_count() + if total >= replicas: + evicted = await asyncio.to_thread(self._evict_oldest_warm) + self._log_replicas_soft_cap(replicas, sandbox_id, evicted) + + info = await asyncio.to_thread(self._backend.create, thread_id, sandbox_id, extra_mounts=extra_mounts or None) + + # Wait for sandbox to be ready without blocking the event loop. + if not await wait_for_sandbox_ready_async(info.sandbox_url, timeout=60): + await asyncio.to_thread(self._backend.destroy, info) + raise RuntimeError(f"Sandbox {sandbox_id} failed to become ready within timeout at {info.sandbox_url}") + + return self._register_created_sandbox(thread_id, sandbox_id, info) def get(self, sandbox_id: str) -> Sandbox | None: """Get a sandbox by ID. Updates last activity timestamp. diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/backend.py index 0200ba783..a1db1bf31 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/backend.py @@ -2,10 +2,12 @@ from __future__ import annotations +import asyncio import logging import time from abc import ABC, abstractmethod +import httpx import requests from .sandbox_info import SandboxInfo @@ -35,6 +37,34 @@ def wait_for_sandbox_ready(sandbox_url: str, timeout: int = 30) -> bool: return False +async def wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool: + """Async variant of sandbox readiness polling. + + Use this from async runtime paths so sandbox startup waits do not block the + event loop. The synchronous ``wait_for_sandbox_ready`` function remains for + existing synchronous backend/provider call sites. + """ + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + + async with httpx.AsyncClient(timeout=5) as client: + while True: + remaining = deadline - loop.time() + if remaining <= 0: + break + try: + response = await client.get(f"{sandbox_url}/v1/sandbox", timeout=min(5.0, remaining)) + if response.status_code == 200: + return True + except httpx.RequestError: + pass + remaining = deadline - loop.time() + if remaining <= 0: + break + await asyncio.sleep(min(poll_interval, remaining)) + return False + + class SandboxBackend(ABC): """Abstract base for sandbox provisioning backends. @@ -44,7 +74,7 @@ class SandboxBackend(ABC): """ @abstractmethod - def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: + def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: """Create/provision a new sandbox. Args: diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py index 92d933d89..69d838208 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/local_backend.py @@ -241,7 +241,7 @@ class LocalContainerBackend(SandboxBackend): # ── SandboxBackend interface ────────────────────────────────────────── - def create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: + def create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: """Start a new container and return its connection info. Args: diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py index 9b23e05dc..83925df13 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/remote_backend.py @@ -59,7 +59,7 @@ class RemoteSandboxBackend(SandboxBackend): def create( self, - thread_id: str, + thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None, ) -> SandboxInfo: @@ -132,7 +132,7 @@ class RemoteSandboxBackend(SandboxBackend): logger.warning("Provisioner list_running failed: %s", exc) return [] - def _provisioner_create(self, thread_id: str, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: + def _provisioner_create(self, thread_id: str | None, sandbox_id: str, extra_mounts: list[tuple[str, str, bool]] | None = None) -> SandboxInfo: """POST /api/sandboxes → create Pod + Service.""" try: resp = requests.post( diff --git a/backend/packages/harness/deerflow/sandbox/middleware.py b/backend/packages/harness/deerflow/sandbox/middleware.py index deefc2397..f40781333 100644 --- a/backend/packages/harness/deerflow/sandbox/middleware.py +++ b/backend/packages/harness/deerflow/sandbox/middleware.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import NotRequired, override @@ -48,6 +49,15 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): logger.info(f"Acquiring sandbox {sandbox_id}") return sandbox_id + async def _acquire_sandbox_async(self, thread_id: str) -> str: + provider = get_sandbox_provider() + sandbox_id = await provider.acquire_async(thread_id) + logger.info(f"Acquiring sandbox {sandbox_id}") + return sandbox_id + + async def _release_sandbox_async(self, sandbox_id: str) -> None: + await asyncio.to_thread(get_sandbox_provider().release, sandbox_id) + @override def before_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: # Skip acquisition if lazy_init is enabled @@ -64,6 +74,23 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): return {"sandbox": {"sandbox_id": sandbox_id}} return super().before_agent(state, runtime) + @override + async def abefore_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: + # Skip acquisition if lazy_init is enabled + if self._lazy_init: + return await super().abefore_agent(state, runtime) + + # Eager initialization (original behavior), but use the async provider + # hook so blocking sandbox startup/polling runs outside the event loop. + if "sandbox" not in state or state["sandbox"] is None: + thread_id = (runtime.context or {}).get("thread_id") + if thread_id is None: + return await super().abefore_agent(state, runtime) + sandbox_id = await self._acquire_sandbox_async(thread_id) + logger.info(f"Assigned sandbox {sandbox_id} to thread {thread_id}") + return {"sandbox": {"sandbox_id": sandbox_id}} + return await super().abefore_agent(state, runtime) + @override def after_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: sandbox = state.get("sandbox") @@ -81,3 +108,21 @@ class SandboxMiddleware(AgentMiddleware[SandboxMiddlewareState]): # No sandbox to release return super().after_agent(state, runtime) + + @override + async def aafter_agent(self, state: SandboxMiddlewareState, runtime: Runtime) -> dict | None: + sandbox = state.get("sandbox") + if sandbox is not None: + sandbox_id = sandbox["sandbox_id"] + logger.info(f"Releasing sandbox {sandbox_id}") + await self._release_sandbox_async(sandbox_id) + return None + + if (runtime.context or {}).get("sandbox_id") is not None: + sandbox_id = runtime.context.get("sandbox_id") + logger.info(f"Releasing sandbox {sandbox_id} from context") + await self._release_sandbox_async(sandbox_id) + return None + + # No sandbox to release + return await super().aafter_agent(state, runtime) diff --git a/backend/packages/harness/deerflow/sandbox/sandbox_provider.py b/backend/packages/harness/deerflow/sandbox/sandbox_provider.py index 0aa4d619a..b989f7830 100644 --- a/backend/packages/harness/deerflow/sandbox/sandbox_provider.py +++ b/backend/packages/harness/deerflow/sandbox/sandbox_provider.py @@ -1,3 +1,4 @@ +import asyncio from abc import ABC, abstractmethod from deerflow.config import get_app_config @@ -19,6 +20,16 @@ class SandboxProvider(ABC): """ pass + async def acquire_async(self, thread_id: str | None = None) -> str: + """Acquire a sandbox without blocking the event loop. + + Most sandbox providers expose a synchronous lifecycle API because local + Docker/provisioner operations are blocking. Async runtimes should call + this method so those blocking operations run in a worker thread instead + of stalling the event loop. + """ + return await asyncio.to_thread(self.acquire, thread_id) + @abstractmethod def get(self, sandbox_id: str) -> Sandbox | None: """Get a sandbox environment by ID. diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index 2694e9406..c8c0b06fb 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -1,6 +1,8 @@ +import asyncio import posixpath import re import shlex +from collections.abc import Callable from pathlib import Path from langchain.tools import tool @@ -1111,6 +1113,68 @@ def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox: return sandbox +async def ensure_sandbox_initialized_async(runtime: Runtime | None = None) -> Sandbox: + """Async counterpart to ``ensure_sandbox_initialized`` for tool runtimes. + + This keeps lazy sandbox acquisition on the async provider hook, so AIO + sandbox startup and readiness polling do not fall back to synchronous + ``provider.acquire()`` during async tool execution. + """ + if runtime is None: + raise SandboxRuntimeError("Tool runtime not available") + + if runtime.state is None: + raise SandboxRuntimeError("Tool runtime state not available") + + sandbox_state = runtime.state.get("sandbox") + if sandbox_state is not None: + sandbox_id = sandbox_state.get("sandbox_id") + if sandbox_id is not None: + sandbox = get_sandbox_provider().get(sandbox_id) + if sandbox is not None: + if runtime.context is not None: + runtime.context["sandbox_id"] = sandbox_id + return sandbox + + thread_id = runtime.context.get("thread_id") if runtime.context else None + if thread_id is None: + thread_id = runtime.config.get("configurable", {}).get("thread_id") if runtime.config else None + if thread_id is None: + raise SandboxRuntimeError("Thread ID not available in runtime context") + + provider = get_sandbox_provider() + sandbox_id = await provider.acquire_async(thread_id) + + runtime.state["sandbox"] = {"sandbox_id": sandbox_id} + + sandbox = provider.get(sandbox_id) + if sandbox is None: + raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id) + + if runtime.context is not None: + runtime.context["sandbox_id"] = sandbox_id + return sandbox + + +async def _run_sync_tool_after_async_sandbox_init( + func: Callable[..., str] | None, + runtime: Runtime, + *args: object, +) -> str: + """Initialize lazily via async provider, then run sync tool body off-thread.""" + try: + await ensure_sandbox_initialized_async(runtime) + except SandboxError as e: + return f"Error: {e}" + except Exception as e: + return f"Error: Unexpected error initializing sandbox: {_sanitize_error(e, runtime)}" + + if func is None: + return "Error: Tool implementation not available" + + return await asyncio.to_thread(func, runtime, *args) + + def ensure_thread_directories_exist(runtime: Runtime | None) -> None: """Ensure thread data directories (workspace, uploads, outputs) exist. @@ -1273,6 +1337,13 @@ def bash_tool(runtime: Runtime, description: str, command: str) -> str: return f"Error: Unexpected error executing command: {_sanitize_error(e, runtime)}" +async def _bash_tool_async(runtime: Runtime, description: str, command: str) -> str: + return await _run_sync_tool_after_async_sandbox_init(bash_tool.func, runtime, description, command) + + +bash_tool.coroutine = _bash_tool_async + + @tool("ls", parse_docstring=True) def ls_tool(runtime: Runtime, description: str, path: str) -> str: """List the contents of a directory up to 2 levels deep in tree format. @@ -1320,6 +1391,13 @@ def ls_tool(runtime: Runtime, description: str, path: str) -> str: return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}" +async def _ls_tool_async(runtime: Runtime, description: str, path: str) -> str: + return await _run_sync_tool_after_async_sandbox_init(ls_tool.func, runtime, description, path) + + +ls_tool.coroutine = _ls_tool_async + + @tool("glob", parse_docstring=True) def glob_tool( runtime: Runtime, @@ -1370,6 +1448,28 @@ def glob_tool( return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}" +async def _glob_tool_async( + runtime: Runtime, + description: str, + pattern: str, + path: str, + include_dirs: bool = False, + max_results: int = _DEFAULT_GLOB_MAX_RESULTS, +) -> str: + return await _run_sync_tool_after_async_sandbox_init( + glob_tool.func, + runtime, + description, + pattern, + path, + include_dirs, + max_results, + ) + + +glob_tool.coroutine = _glob_tool_async + + @tool("grep", parse_docstring=True) def grep_tool( runtime: Runtime, @@ -1440,6 +1540,32 @@ def grep_tool( return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}" +async def _grep_tool_async( + runtime: Runtime, + description: str, + pattern: str, + path: str, + glob: str | None = None, + literal: bool = False, + case_sensitive: bool = False, + max_results: int = _DEFAULT_GREP_MAX_RESULTS, +) -> str: + return await _run_sync_tool_after_async_sandbox_init( + grep_tool.func, + runtime, + description, + pattern, + path, + glob, + literal, + case_sensitive, + max_results, + ) + + +grep_tool.coroutine = _grep_tool_async + + @tool("read_file", parse_docstring=True) def read_file_tool( runtime: Runtime, @@ -1495,6 +1621,19 @@ def read_file_tool( return f"Error: Unexpected error reading file: {_sanitize_error(e, runtime)}" +async def _read_file_tool_async( + runtime: Runtime, + description: str, + path: str, + start_line: int | None = None, + end_line: int | None = None, +) -> str: + return await _run_sync_tool_after_async_sandbox_init(read_file_tool.func, runtime, description, path, start_line, end_line) + + +read_file_tool.coroutine = _read_file_tool_async + + @tool("write_file", parse_docstring=True) def write_file_tool( runtime: Runtime, @@ -1536,6 +1675,19 @@ def write_file_tool( return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}" +async def _write_file_tool_async( + runtime: Runtime, + description: str, + path: str, + content: str, + append: bool = False, +) -> str: + return await _run_sync_tool_after_async_sandbox_init(write_file_tool.func, runtime, description, path, content, append) + + +write_file_tool.coroutine = _write_file_tool_async + + @tool("str_replace", parse_docstring=True) def str_replace_tool( runtime: Runtime, @@ -1585,3 +1737,25 @@ def str_replace_tool( return f"Error: Permission denied accessing file: {requested_path}" except Exception as e: return f"Error: Unexpected error replacing string: {_sanitize_error(e, runtime)}" + + +async def _str_replace_tool_async( + runtime: Runtime, + description: str, + path: str, + old_str: str, + new_str: str, + replace_all: bool = False, +) -> str: + return await _run_sync_tool_after_async_sandbox_init( + str_replace_tool.func, + runtime, + description, + path, + old_str, + new_str, + replace_all, + ) + + +str_replace_tool.coroutine = _str_replace_tool_async diff --git a/backend/tests/test_aio_sandbox_provider.py b/backend/tests/test_aio_sandbox_provider.py index 732d52170..4b3d215b3 100644 --- a/backend/tests/test_aio_sandbox_provider.py +++ b/backend/tests/test_aio_sandbox_provider.py @@ -1,5 +1,6 @@ """Tests for AioSandboxProvider mount helpers.""" +import asyncio import importlib from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -140,6 +141,182 @@ def test_discover_or_create_only_unlocks_when_lock_succeeds(tmp_path, monkeypatc assert unlock_calls == [] +@pytest.mark.anyio +async def test_acquire_async_uses_async_readiness_polling(monkeypatch): + """AioSandboxProvider async creation must not use sync readiness polling.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(None) + provider._config = {"replicas": 3} + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + provider._backend = SimpleNamespace( + create=MagicMock(return_value=aio_mod.SandboxInfo(sandbox_id="sandbox-async", sandbox_url="http://sandbox")), + destroy=MagicMock(), + discover=MagicMock(return_value=None), + ) + + async_readiness_calls: list[tuple[str, int]] = [] + + async def fake_wait_for_sandbox_ready_async(sandbox_url: str, timeout: int = 30, poll_interval: float = 1.0) -> bool: + async_readiness_calls.append((sandbox_url, timeout)) + return True + + monkeypatch.setattr(aio_mod, "wait_for_sandbox_ready_async", fake_wait_for_sandbox_ready_async) + monkeypatch.setattr( + aio_mod, + "wait_for_sandbox_ready", + lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("sync readiness should not be used")), + ) + + sandbox_id = await provider._create_sandbox_async("thread-async", "sandbox-async") + + assert sandbox_id == "sandbox-async" + assert async_readiness_calls == [("http://sandbox", 60)] + assert provider._backend.destroy.call_count == 0 + assert provider._thread_sandboxes["thread-async"] == "sandbox-async" + + +@pytest.mark.anyio +async def test_discover_or_create_with_lock_async_offloads_lock_file_open_and_close(tmp_path, monkeypatch): + """Async lock path must not open or close lock files on the event loop.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._discover_or_create_with_lock_async = aio_mod.AioSandboxProvider._discover_or_create_with_lock_async.__get__( + provider, + aio_mod.AioSandboxProvider, + ) + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {"thread-async-lock": "sandbox-async-lock"} + provider._sandboxes = {"sandbox-async-lock": aio_mod.AioSandbox(id="sandbox-async-lock", base_url="http://sandbox")} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + provider._backend = SimpleNamespace(discover=MagicMock(return_value=None)) + + monkeypatch.setattr(aio_mod, "get_paths", lambda: Paths(base_dir=tmp_path)) + + to_thread_calls: list[object] = [] + + async def fake_to_thread(func, /, *args, **kwargs): + to_thread_calls.append(func) + return func(*args, **kwargs) + + monkeypatch.setattr(aio_mod.asyncio, "to_thread", fake_to_thread) + + sandbox_id = await provider._discover_or_create_with_lock_async("thread-async-lock", "sandbox-async-lock") + + assert sandbox_id == "sandbox-async-lock" + assert aio_mod._open_lock_file in to_thread_calls + assert any(getattr(func, "__name__", "") == "close" for func in to_thread_calls) + + +@pytest.mark.anyio +async def test_acquire_thread_lock_async_uses_dedicated_executor(monkeypatch): + """Per-thread lock waits should not consume the default asyncio.to_thread pool.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + lock = aio_mod.threading.Lock() + + async def fail_to_thread(*_args, **_kwargs): + raise AssertionError("thread-lock acquisition must not use asyncio.to_thread") + + monkeypatch.setattr(aio_mod.asyncio, "to_thread", fail_to_thread) + + await aio_mod._acquire_thread_lock_async(lock) + try: + assert not lock.acquire(blocking=False) + finally: + lock.release() + + +@pytest.mark.anyio +async def test_acquire_async_cancellation_does_not_leak_thread_lock(tmp_path): + """Cancelled async lock waiters must not leave the per-thread lock held.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + + thread_id = "thread-cancel-lock" + thread_lock = provider._get_thread_lock(thread_id) + thread_lock.acquire() + + task = asyncio.create_task(provider.acquire_async(thread_id)) + await asyncio.sleep(0.05) + task.cancel() + + try: + await task + except asyncio.CancelledError: + pass + + thread_lock.release() + deadline = asyncio.get_running_loop().time() + 1 + while asyncio.get_running_loop().time() < deadline: + acquired = thread_lock.acquire(blocking=False) + if acquired: + thread_lock.release() + return + await asyncio.sleep(0.01) + + pytest.fail("provider thread lock was leaked after cancelling acquire_async") + + +@pytest.mark.anyio +async def test_acquire_async_cancelled_waiter_does_not_block_successor(tmp_path, monkeypatch): + """A cancelled waiter must not prevent the next live waiter from acquiring.""" + aio_mod = importlib.import_module("deerflow.community.aio_sandbox.aio_sandbox_provider") + provider = _make_provider(tmp_path) + provider._thread_locks = {} + provider._warm_pool = {} + provider._sandbox_infos = {} + provider._thread_sandboxes = {} + provider._last_activity = {} + provider._lock = aio_mod.threading.Lock() + + async def fake_acquire_internal_async(thread_id: str | None) -> str: + assert thread_id == "thread-successor-lock" + await asyncio.sleep(0) + return "sandbox-successor" + + monkeypatch.setattr(provider, "_acquire_internal_async", fake_acquire_internal_async) + + thread_id = "thread-successor-lock" + thread_lock = provider._get_thread_lock(thread_id) + thread_lock.acquire() + + cancelled_waiter = asyncio.create_task(provider.acquire_async(thread_id)) + await asyncio.sleep(0.05) + cancelled_waiter.cancel() + try: + await cancelled_waiter + except asyncio.CancelledError: + pass + + live_waiter = asyncio.create_task(provider.acquire_async(thread_id)) + thread_lock.release() + + assert await asyncio.wait_for(live_waiter, timeout=1) == "sandbox-successor" + + deadline = asyncio.get_running_loop().time() + 1 + while asyncio.get_running_loop().time() < deadline: + acquired = thread_lock.acquire(blocking=False) + if acquired: + thread_lock.release() + return + await asyncio.sleep(0.01) + + pytest.fail("provider thread lock was not released after successor acquire_async") + + def test_remote_backend_create_forwards_effective_user_id(monkeypatch): """Provisioner mode must receive user_id so PVC subPath matches user isolation.""" remote_mod = importlib.import_module("deerflow.community.aio_sandbox.remote_backend") diff --git a/backend/tests/test_aio_sandbox_readiness.py b/backend/tests/test_aio_sandbox_readiness.py new file mode 100644 index 000000000..1560bbab3 --- /dev/null +++ b/backend/tests/test_aio_sandbox_readiness.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from deerflow.community.aio_sandbox import backend as readiness + + +class _FakeAsyncClient: + def __init__(self, *, responses: list[object], calls: list[str], timeout: float, request_timeouts: list[float] | None = None) -> None: + self._responses = responses + self._calls = calls + self._timeout = timeout + self._request_timeouts = request_timeouts + + async def __aenter__(self) -> _FakeAsyncClient: + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + async def get(self, url: str, *, timeout: float): + self._calls.append(url) + if self._request_timeouts is not None: + self._request_timeouts.append(timeout) + response = self._responses.pop(0) + if isinstance(response, BaseException): + raise response + return response + + +class _FakeLoop: + def __init__(self, times: list[float]) -> None: + self._times = times + self._index = 0 + + def time(self) -> float: + value = self._times[self._index] + self._index += 1 + return value + + +@pytest.mark.anyio +async def test_wait_for_sandbox_ready_async_uses_nonblocking_polling(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + sleeps: list[float] = [] + + def fake_client(*, timeout: float): + return _FakeAsyncClient( + responses=[SimpleNamespace(status_code=503), SimpleNamespace(status_code=200)], + calls=calls, + timeout=timeout, + ) + + async def fake_sleep(delay: float) -> None: + sleeps.append(delay) + + monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client) + monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep) + monkeypatch.setattr(readiness.requests, "get", lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("requests.get should not be used"))) + monkeypatch.setattr(readiness.time, "sleep", lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("time.sleep should not be used"))) + + assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.05) is True + + assert calls == ["http://sandbox/v1/sandbox", "http://sandbox/v1/sandbox"] + assert sleeps == [0.05] + + +@pytest.mark.anyio +async def test_wait_for_sandbox_ready_async_retries_request_errors(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + sleeps: list[float] = [] + + def fake_client(*, timeout: float): + return _FakeAsyncClient( + responses=[readiness.httpx.ConnectError("not ready"), SimpleNamespace(status_code=200)], + calls=calls, + timeout=timeout, + ) + + async def fake_sleep(delay: float) -> None: + sleeps.append(delay) + + monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client) + monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep) + + assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=5, poll_interval=0.01) is True + + assert len(calls) == 2 + assert sleeps == [0.01] + + +@pytest.mark.anyio +async def test_wait_for_sandbox_ready_async_clamps_request_and_sleep_to_deadline(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[str] = [] + request_timeouts: list[float] = [] + sleeps: list[float] = [] + + def fake_client(*, timeout: float): + return _FakeAsyncClient( + responses=[SimpleNamespace(status_code=503)], + calls=calls, + timeout=timeout, + request_timeouts=request_timeouts, + ) + + async def fake_sleep(delay: float) -> None: + sleeps.append(delay) + + monkeypatch.setattr(readiness.httpx, "AsyncClient", fake_client) + monkeypatch.setattr(readiness.asyncio, "sleep", fake_sleep) + monkeypatch.setattr(readiness.asyncio, "get_running_loop", lambda: _FakeLoop([100.0, 100.5, 101.75, 102.0])) + + assert await readiness.wait_for_sandbox_ready_async("http://sandbox", timeout=2, poll_interval=1.0) is False + + assert calls == ["http://sandbox/v1/sandbox"] + assert request_timeouts == [1.5] + assert sleeps == [0.25] diff --git a/backend/tests/test_remote_sandbox_backend.py b/backend/tests/test_remote_sandbox_backend.py index ed4dd7991..beb7564c5 100644 --- a/backend/tests/test_remote_sandbox_backend.py +++ b/backend/tests/test_remote_sandbox_backend.py @@ -159,6 +159,26 @@ def test_provisioner_create_returns_sandbox_info(monkeypatch): assert info.sandbox_url == "http://k3s:31001" +def test_provisioner_create_accepts_anonymous_thread_id(monkeypatch): + backend = RemoteSandboxBackend("http://provisioner:8002") + + def mock_post(url: str, json: dict, timeout: int): + assert url == "http://provisioner:8002/api/sandboxes" + assert json == { + "sandbox_id": "anon123", + "thread_id": None, + "user_id": "test-user-autouse", + } + assert timeout == 30 + return _StubResponse(payload={"sandbox_id": "anon123", "sandbox_url": "http://k3s:31002"}) + + monkeypatch.setattr(requests, "post", mock_post) + + info = backend.create(None, "anon123") + assert info.sandbox_id == "anon123" + assert info.sandbox_url == "http://k3s:31002" + + def test_provisioner_create_raises_runtime_error_on_request_exception(monkeypatch): backend = RemoteSandboxBackend("http://provisioner:8002") diff --git a/backend/tests/test_sandbox_middleware.py b/backend/tests/test_sandbox_middleware.py new file mode 100644 index 000000000..e3daa3088 --- /dev/null +++ b/backend/tests/test_sandbox_middleware.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import asyncio + +import pytest +from langchain.agents.middleware import AgentMiddleware +from langchain.tools import ToolRuntime +from langgraph.runtime import Runtime + +from deerflow.sandbox.middleware import SandboxMiddleware +from deerflow.sandbox.sandbox import Sandbox +from deerflow.sandbox.sandbox_provider import SandboxProvider, reset_sandbox_provider, set_sandbox_provider +from deerflow.sandbox.search import GrepMatch +from deerflow.sandbox.tools import ls_tool + + +class _SyncProvider(SandboxProvider): + def __init__(self) -> None: + self.thread_ids: list[str | None] = [] + + def acquire(self, thread_id: str | None = None) -> str: + self.thread_ids.append(thread_id) + return "sync-sandbox" + + def get(self, sandbox_id: str) -> Sandbox | None: + return None + + def release(self, sandbox_id: str) -> None: + return None + + +class _SandboxStub(Sandbox): + def execute_command(self, command: str) -> str: + return "OK" + + def read_file(self, path: str) -> str: + return "content" + + def download_file(self, path: str) -> bytes: + return b"content" + + def list_dir(self, path: str, max_depth: int = 2) -> list[str]: + return ["/mnt/user-data/workspace/file.txt"] + + def write_file(self, path: str, content: str, append: bool = False) -> None: + return None + + def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]: + return [], False + + def grep( + self, + path: str, + pattern: str, + *, + glob: str | None = None, + literal: bool = False, + case_sensitive: bool = False, + max_results: int = 100, + ) -> tuple[list[GrepMatch], bool]: + return [], False + + def update_file(self, path: str, content: bytes) -> None: + return None + + +class _AsyncOnlyProvider(SandboxProvider): + def __init__(self) -> None: + self.thread_ids: list[str | None] = [] + self.released_ids: list[str] = [] + self.sandbox = _SandboxStub("async-sandbox") + + def acquire(self, thread_id: str | None = None) -> str: + raise AssertionError("async middleware should not call sync acquire") + + async def acquire_async(self, thread_id: str | None = None) -> str: + self.thread_ids.append(thread_id) + return "async-sandbox" + + def get(self, sandbox_id: str) -> Sandbox | None: + if sandbox_id == "async-sandbox": + return self.sandbox + return None + + def release(self, sandbox_id: str) -> None: + self.released_ids.append(sandbox_id) + return None + + +@pytest.mark.anyio +async def test_provider_default_acquire_async_offloads_sync_acquire(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _SyncProvider() + calls: list[tuple[object, tuple[object, ...]]] = [] + + async def fake_to_thread(func, /, *args): + calls.append((func, args)) + return func(*args) + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + + sandbox_id = await provider.acquire_async("thread-1") + + assert sandbox_id == "sync-sandbox" + assert provider.thread_ids == ["thread-1"] + assert calls == [(provider.acquire, ("thread-1",))] + + +@pytest.mark.anyio +async def test_abefore_agent_uses_async_provider_acquire() -> None: + provider = _AsyncOnlyProvider() + set_sandbox_provider(provider) + try: + middleware = SandboxMiddleware(lazy_init=False) + + result = await middleware.abefore_agent({}, Runtime(context={"thread_id": "thread-2"})) + finally: + reset_sandbox_provider() + + assert result == {"sandbox": {"sandbox_id": "async-sandbox"}} + assert provider.thread_ids == ["thread-2"] + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("middleware", "state", "runtime"), + [ + (SandboxMiddleware(lazy_init=True), {}, Runtime(context={"thread_id": "thread-lazy"})), + (SandboxMiddleware(lazy_init=False), {}, Runtime(context={})), + (SandboxMiddleware(lazy_init=False), {"sandbox": {"sandbox_id": "existing"}}, Runtime(context={"thread_id": "thread-existing"})), + ], +) +async def test_abefore_agent_delegates_to_super_when_not_acquiring( + monkeypatch: pytest.MonkeyPatch, + middleware: SandboxMiddleware, + state: dict, + runtime: Runtime, +) -> None: + calls: list[tuple[dict, Runtime]] = [] + + async def fake_super_abefore_agent(self, state_arg, runtime_arg): + calls.append((state_arg, runtime_arg)) + return {"delegated": True} + + monkeypatch.setattr(AgentMiddleware, "abefore_agent", fake_super_abefore_agent) + + result = await middleware.abefore_agent(state, runtime) + + assert result == {"delegated": True} + assert calls == [(state, runtime)] + + +@pytest.mark.anyio +async def test_default_lazy_tool_acquisition_uses_async_provider() -> None: + provider = _AsyncOnlyProvider() + set_sandbox_provider(provider) + try: + runtime = ToolRuntime( + state={}, + context={"thread_id": "thread-lazy"}, + config={"configurable": {}}, + stream_writer=lambda _: None, + tools=[], + tool_call_id="call-1", + store=None, + ) + + result = await ls_tool.ainvoke({"runtime": runtime, "description": "list workspace", "path": "/mnt/user-data/workspace"}) + finally: + reset_sandbox_provider() + + assert result == "/mnt/user-data/workspace/file.txt" + assert provider.thread_ids == ["thread-lazy"] + assert runtime.state["sandbox"] == {"sandbox_id": "async-sandbox"} + assert runtime.context["sandbox_id"] == "async-sandbox" + + +@pytest.mark.anyio +@pytest.mark.parametrize( + ("state", "runtime", "expected_sandbox_id"), + [ + ({"sandbox": {"sandbox_id": "state-sandbox"}}, Runtime(context={}), "state-sandbox"), + ({}, Runtime(context={"sandbox_id": "context-sandbox"}), "context-sandbox"), + ], +) +async def test_aafter_agent_releases_sandbox_off_thread( + monkeypatch: pytest.MonkeyPatch, + state: dict, + runtime: Runtime, + expected_sandbox_id: str, +) -> None: + provider = _AsyncOnlyProvider() + to_thread_calls: list[tuple[object, tuple[object, ...]]] = [] + + async def fake_to_thread(func, /, *args): + to_thread_calls.append((func, args)) + return func(*args) + + monkeypatch.setattr(asyncio, "to_thread", fake_to_thread) + set_sandbox_provider(provider) + try: + result = await SandboxMiddleware().aafter_agent(state, runtime) + finally: + reset_sandbox_provider() + + assert result is None + assert provider.released_ids == [expected_sandbox_id] + assert to_thread_calls == [(provider.release, (expected_sandbox_id,))] + + +@pytest.mark.anyio +async def test_aafter_agent_delegates_to_super_when_no_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[dict, Runtime]] = [] + + async def fake_super_aafter_agent(self, state_arg, runtime_arg): + calls.append((state_arg, runtime_arg)) + return {"delegated": True} + + monkeypatch.setattr(AgentMiddleware, "aafter_agent", fake_super_aafter_agent) + + state = {} + runtime = Runtime(context={}) + result = await SandboxMiddleware().aafter_agent(state, runtime) + + assert result == {"delegated": True} + assert calls == [(state, runtime)] From 923f516debc9989cecc688eaa47c32a32b30b86f Mon Sep 17 00:00:00 2001 From: Airene Fang Date: Thu, 21 May 2026 14:48:28 +0800 Subject: [PATCH 57/86] feat(trace):LangGraph -> lead_agent and set custom agent_name to run_name (#3101) * feat(trace):LangGraph -> lead_agent and set user custom agent name to run_name * feat(trace):follow github copilot suggest * feat(trace):Refactor run_name resolution and improve test coverage --- backend/app/gateway/services.py | 2 + .../harness/deerflow/runtime/runs/naming.py | 16 +++ .../harness/deerflow/runtime/runs/worker.py | 4 + backend/tests/test_gateway_services.py | 4 + backend/tests/test_run_naming.py | 34 ++++++ backend/tests/test_run_worker_rollback.py | 102 ++++++++++++++++++ 6 files changed, 162 insertions(+) create mode 100644 backend/packages/harness/deerflow/runtime/runs/naming.py create mode 100644 backend/tests/test_run_naming.py diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 96521b86f..4713d303e 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -32,6 +32,7 @@ from deerflow.runtime import ( UnsupportedStrategyError, run_agent, ) +from deerflow.runtime.runs.naming import resolve_root_run_name logger = logging.getLogger(__name__) @@ -235,6 +236,7 @@ def build_run_config( target = config.setdefault("configurable", {}) if target is not None and "agent_name" not in target: target["agent_name"] = normalized + config.setdefault("run_name", resolve_root_run_name(config, normalized)) if metadata: config.setdefault("metadata", {}).update(metadata) return config diff --git a/backend/packages/harness/deerflow/runtime/runs/naming.py b/backend/packages/harness/deerflow/runtime/runs/naming.py new file mode 100644 index 000000000..57c67f17c --- /dev/null +++ b/backend/packages/harness/deerflow/runtime/runs/naming.py @@ -0,0 +1,16 @@ +"""Run naming helpers for LangChain/LangSmith tracing.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + + +def resolve_root_run_name(config: Mapping[str, Any], assistant_id: str | None) -> str: + for container_name in ("context", "configurable"): + container = config.get(container_name) + if isinstance(container, Mapping): + agent_name = container.get("agent_name") + if isinstance(agent_name, str) and agent_name.strip(): + return agent_name + return assistant_id or "lead_agent" diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index f78d425a2..09e3c66e9 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -33,6 +33,7 @@ from deerflow.runtime.serialization import serialize from deerflow.runtime.stream_bridge import StreamBridge from .manager import RunManager, RunRecord +from .naming import resolve_root_run_name from .schemas import RunStatus logger = logging.getLogger(__name__) @@ -224,6 +225,9 @@ async def run_agent( if journal is not None: config.setdefault("callbacks", []).append(journal) + # Resolve after runtime context installation so context/configurable reflect + # the agent name that this run will actually execute. + config.setdefault("run_name", resolve_root_run_name(config, record.assistant_id)) runnable_config = RunnableConfig(**config) if ctx.app_config is not None and _agent_factory_supports_app_config(agent_factory): agent = agent_factory(config=runnable_config, app_config=ctx.app_config) diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py index b024405b5..aa9e20e78 100644 --- a/backend/tests/test_gateway_services.py +++ b/backend/tests/test_gateway_services.py @@ -114,6 +114,7 @@ def test_build_run_config_custom_agent_injects_agent_name(): config = build_run_config("thread-1", None, None, assistant_id="finalis") assert config["configurable"]["agent_name"] == "finalis" + assert config["run_name"] == "finalis" def test_build_run_config_lead_agent_no_agent_name(): @@ -122,6 +123,7 @@ def test_build_run_config_lead_agent_no_agent_name(): config = build_run_config("thread-1", None, None, assistant_id="lead_agent") assert "agent_name" not in config["configurable"] + assert "run_name" not in config def test_build_run_config_none_assistant_id_no_agent_name(): @@ -130,6 +132,7 @@ def test_build_run_config_none_assistant_id_no_agent_name(): config = build_run_config("thread-1", None, None, assistant_id=None) assert "agent_name" not in config["configurable"] + assert "run_name" not in config def test_build_run_config_explicit_agent_name_not_overwritten(): @@ -143,6 +146,7 @@ def test_build_run_config_explicit_agent_name_not_overwritten(): assistant_id="other-agent", ) assert config["configurable"]["agent_name"] == "explicit-agent" + assert config["run_name"] == "explicit-agent" def test_build_run_config_context_custom_agent_injects_agent_name(): diff --git a/backend/tests/test_run_naming.py b/backend/tests/test_run_naming.py new file mode 100644 index 000000000..4afb6fad7 --- /dev/null +++ b/backend/tests/test_run_naming.py @@ -0,0 +1,34 @@ +from deerflow.runtime.runs.naming import resolve_root_run_name + + +def test_resolve_root_run_name_from_context_agent_name(): + assert resolve_root_run_name({"context": {"agent_name": "finalis"}}, "lead_agent") == "finalis" + + +def test_resolve_root_run_name_from_configurable_agent_name(): + assert resolve_root_run_name({"configurable": {"agent_name": "finalis"}}, "lead_agent") == "finalis" + + +def test_resolve_root_run_name_falls_back_to_assistant_id(): + assert resolve_root_run_name({}, "my-agent") == "my-agent" + + +def test_resolve_root_run_name_falls_back_to_lead_agent(): + assert resolve_root_run_name({}, None) == "lead_agent" + + +def test_resolve_root_run_name_prefers_context_over_configurable(): + config = { + "context": {"agent_name": "ctx-agent"}, + "configurable": {"agent_name": "cfg-agent"}, + } + + assert resolve_root_run_name(config, "lead_agent") == "ctx-agent" + + +def test_resolve_root_run_name_ignores_blank_agent_name(): + assert resolve_root_run_name({"context": {"agent_name": " "}}, "my-agent") == "my-agent" + + +def test_resolve_root_run_name_ignores_non_string_agent_name(): + assert resolve_root_run_name({"context": {"agent_name": None}}, "my-agent") == "my-agent" diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py index 72e3ac98e..5a8ec71f7 100644 --- a/backend/tests/test_run_worker_rollback.py +++ b/backend/tests/test_run_worker_rollback.py @@ -95,6 +95,108 @@ async def test_run_agent_threads_explicit_app_config_into_config_only_factory(): bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60) +@pytest.mark.anyio +async def test_run_agent_defaults_root_run_name_from_assistant_id(): + run_manager = RunManager() + record = await run_manager.create("thread-1", assistant_id="lead_agent") + bridge = SimpleNamespace( + publish=AsyncMock(), + publish_end=AsyncMock(), + cleanup=AsyncMock(), + ) + captured: dict[str, object] = {} + + class DummyAgent: + async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False): + captured["astream_run_name"] = config["run_name"] + yield {"messages": []} + + def factory(*, config): + captured["factory_run_name"] = config["run_name"] + return DummyAgent() + + await run_agent( + bridge, + run_manager, + record, + ctx=RunContext(checkpointer=None), + agent_factory=factory, + graph_input={}, + config={}, + ) + + assert captured["factory_run_name"] == "lead_agent" + assert captured["astream_run_name"] == "lead_agent" + + +@pytest.mark.anyio +async def test_run_agent_defaults_root_run_name_from_context_agent_name(): + run_manager = RunManager() + record = await run_manager.create("thread-1", assistant_id="lead_agent") + bridge = SimpleNamespace( + publish=AsyncMock(), + publish_end=AsyncMock(), + cleanup=AsyncMock(), + ) + captured: dict[str, object] = {} + + class DummyAgent: + async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False): + captured["astream_run_name"] = config["run_name"] + yield {"messages": []} + + def factory(*, config): + captured["factory_run_name"] = config["run_name"] + return DummyAgent() + + await run_agent( + bridge, + run_manager, + record, + ctx=RunContext(checkpointer=None), + agent_factory=factory, + graph_input={}, + config={"context": {"agent_name": "finalis"}}, + ) + + assert captured["factory_run_name"] == "finalis" + assert captured["astream_run_name"] == "finalis" + + +@pytest.mark.anyio +async def test_run_agent_defaults_root_run_name_from_configurable_agent_name(): + run_manager = RunManager() + record = await run_manager.create("thread-1", assistant_id="lead_agent") + bridge = SimpleNamespace( + publish=AsyncMock(), + publish_end=AsyncMock(), + cleanup=AsyncMock(), + ) + captured: dict[str, object] = {} + + class DummyAgent: + async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False): + captured["astream_run_name"] = config["run_name"] + yield {"messages": []} + + def factory(*, config): + captured["factory_run_name"] = config["run_name"] + return DummyAgent() + + await run_agent( + bridge, + run_manager, + record, + ctx=RunContext(checkpointer=None), + agent_factory=factory, + graph_input={}, + config={"configurable": {"agent_name": "finalis"}}, + ) + + assert captured["factory_run_name"] == "finalis" + assert captured["astream_run_name"] == "finalis" + + @pytest.mark.anyio async def test_rollback_restores_snapshot_without_deleting_thread(): checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) From 31513c2ccb26a3e2a4b637e4ebcae312140ecdf9 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Thu, 21 May 2026 16:22:09 +0800 Subject: [PATCH 58/86] fix(persistence): emit tz-aware timestamps from SQLite-backed stores (#3130) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SQLAlchemy's DateTime(timezone=True) is a no-op on SQLite (the backend has no native tz type), so values round-tripped through the DB come back as naive datetimes. The four SQL _row_to_dict helpers were calling .isoformat() directly on those naive values, shipping timezone-less strings like "2026-05-20T06:10:22.970977" out of the API. The browser's new Date(...) then parses them as local time, shifting recent threads in /threads/search by the local UTC offset (about 8h in Asia/Shanghai). Route the four call sites through coerce_iso() instead — it already normalizes naive values as UTC and emits "+00:00" so the wire format always carries tz. No data migration is needed; existing SQLite rows read back via the corrected serializer. PostgreSQL deployments are unaffected because timestamptz preserves tzinfo end-to-end. Closes #3120 --- .../deerflow/persistence/feedback/sql.py | 4 +- .../harness/deerflow/persistence/run/sql.py | 7 +- .../deerflow/persistence/thread_meta/sql.py | 5 +- .../deerflow/runtime/events/store/db.py | 5 +- backend/tests/test_persistence_timezone.py | 106 ++++++++++++++++++ 5 files changed, 122 insertions(+), 5 deletions(-) create mode 100644 backend/tests/test_persistence_timezone.py diff --git a/backend/packages/harness/deerflow/persistence/feedback/sql.py b/backend/packages/harness/deerflow/persistence/feedback/sql.py index 1db74ce84..cdb5db89b 100644 --- a/backend/packages/harness/deerflow/persistence/feedback/sql.py +++ b/backend/packages/harness/deerflow/persistence/feedback/sql.py @@ -13,6 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.feedback.model import FeedbackRow from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +from deerflow.utils.time import coerce_iso class FeedbackRepository: @@ -24,7 +25,8 @@ class FeedbackRepository: d = row.to_dict() val = d.get("created_at") if isinstance(val, datetime): - d["created_at"] = val.isoformat() + # SQLite drops tzinfo on read; normalize via ``coerce_iso`` so output is always tz-aware. + d["created_at"] = coerce_iso(val) return d async def create( diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index d586a2b13..5679cc68f 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.run.model import RunRow from deerflow.runtime.runs.store.base import RunStore from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +from deerflow.utils.time import coerce_iso class RunRepository(RunStore): @@ -68,11 +69,13 @@ class RunRepository(RunStore): # Remap JSON columns to match RunStore interface d["metadata"] = d.pop("metadata_json", {}) d["kwargs"] = d.pop("kwargs_json", {}) - # Convert datetime to ISO string for consistency with MemoryRunStore + # Convert datetime to ISO string for consistency with MemoryRunStore. + # SQLite drops tzinfo on read despite ``DateTime(timezone=True)`` — + # ``coerce_iso`` normalizes naive datetimes as UTC. for key in ("created_at", "updated_at"): val = d.get(key) if isinstance(val, datetime): - d[key] = val.isoformat() + d[key] = coerce_iso(val) return d async def put( diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 0d3f587de..930128087 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -13,6 +13,7 @@ from deerflow.persistence.json_compat import json_match from deerflow.persistence.thread_meta.base import InvalidMetadataFilterError, ThreadMetaStore from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +from deerflow.utils.time import coerce_iso logger = logging.getLogger(__name__) @@ -28,7 +29,9 @@ class ThreadMetaRepository(ThreadMetaStore): for key in ("created_at", "updated_at"): val = d.get(key) if isinstance(val, datetime): - d[key] = val.isoformat() + # SQLite drops tzinfo despite ``DateTime(timezone=True)``; + # ``coerce_iso`` normalizes naive values as UTC so the wire format always carries tz. + d[key] = coerce_iso(val) return d async def create( diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index b7e54754f..7bb55133e 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from deerflow.persistence.models.run_event import RunEventRow from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id +from deerflow.utils.time import coerce_iso logger = logging.getLogger(__name__) @@ -32,7 +33,9 @@ class DbRunEventStore(RunEventStore): d["metadata"] = d.pop("event_metadata", {}) val = d.get("created_at") if isinstance(val, datetime): - d["created_at"] = val.isoformat() + # SQLite drops tzinfo on read despite ``DateTime(timezone=True)``; + # ``coerce_iso`` normalizes naive datetimes as UTC. + d["created_at"] = coerce_iso(val) d.pop("id", None) # Restore structured content that was JSON-serialized on write. raw = d.get("content", "") diff --git a/backend/tests/test_persistence_timezone.py b/backend/tests/test_persistence_timezone.py new file mode 100644 index 000000000..7cd7b3310 --- /dev/null +++ b/backend/tests/test_persistence_timezone.py @@ -0,0 +1,106 @@ +"""Regression tests for #3120: SQLite-backed stores must emit tz-aware ISO timestamps. + +SQLAlchemy's ``DateTime(timezone=True)`` is a no-op on SQLite because the +backend has no native timezone type, so values read back are naive +``datetime`` instances. The four SQL ``_row_to_dict`` helpers therefore +have to normalize through :func:`deerflow.utils.time.coerce_iso` instead +of calling ``.isoformat()`` directly; otherwise the API ships +timezone-less strings (e.g. ``"2026-05-20T06:10:22.970977"``) and the +frontend's ``new Date(...)`` parses them as local time, shifting recent +threads by the local UTC offset. +""" + +import re + +import pytest + +_TZ_SUFFIX_RE = re.compile(r"(?:\+\d{2}:\d{2}|Z)$") + + +def _assert_tz_aware(value: str | None, *, context: str) -> None: + assert value, f"{context}: expected ISO string, got {value!r}" + assert _TZ_SUFFIX_RE.search(value), f"{context}: timestamp lacks tz suffix: {value!r}" + + +async def _init_sqlite(tmp_path): + from deerflow.persistence.engine import get_session_factory, init_engine + + url = f"sqlite+aiosqlite:///{tmp_path / 'tz.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + return get_session_factory() + + +async def _cleanup(): + from deerflow.persistence.engine import close_engine + + await close_engine() + + +@pytest.mark.anyio +async def test_thread_meta_emits_tz_aware_timestamps(tmp_path): + from deerflow.persistence.thread_meta import ThreadMetaRepository + + repo = ThreadMetaRepository(await _init_sqlite(tmp_path)) + try: + created = await repo.create("t-tz", user_id="u1", display_name="tz") + _assert_tz_aware(created["created_at"], context="thread_meta.create.created_at") + _assert_tz_aware(created["updated_at"], context="thread_meta.create.updated_at") + + # Second read from DB exercises the same _row_to_dict path on a + # value that SQLite has round-tripped (where tzinfo is lost). + fetched = await repo.get("t-tz", user_id="u1") + _assert_tz_aware(fetched["created_at"], context="thread_meta.get.created_at") + _assert_tz_aware(fetched["updated_at"], context="thread_meta.get.updated_at") + + listed = await repo.search(user_id="u1") + assert listed, "search must return the created row" + _assert_tz_aware(listed[0]["created_at"], context="thread_meta.search.created_at") + _assert_tz_aware(listed[0]["updated_at"], context="thread_meta.search.updated_at") + finally: + await _cleanup() + + +@pytest.mark.anyio +async def test_run_repository_emits_tz_aware_timestamps(tmp_path): + from deerflow.persistence.run import RunRepository + + repo = RunRepository(await _init_sqlite(tmp_path)) + try: + await repo.put("r-tz", thread_id="t-tz", user_id="u1") + row = await repo.get("r-tz", user_id="u1") + _assert_tz_aware(row["created_at"], context="run.get.created_at") + _assert_tz_aware(row["updated_at"], context="run.get.updated_at") + finally: + await _cleanup() + + +@pytest.mark.anyio +async def test_feedback_repository_emits_tz_aware_timestamps(tmp_path): + from deerflow.persistence.feedback import FeedbackRepository + + repo = FeedbackRepository(await _init_sqlite(tmp_path)) + try: + record = await repo.create(run_id="r-tz", thread_id="t-tz", rating=1, user_id="u1") + _assert_tz_aware(record["created_at"], context="feedback.create.created_at") + finally: + await _cleanup() + + +@pytest.mark.anyio +async def test_run_event_store_emits_tz_aware_timestamps(tmp_path): + from deerflow.runtime.events.store.db import DbRunEventStore + + store = DbRunEventStore(await _init_sqlite(tmp_path)) + try: + await store.put( + thread_id="t-tz", + run_id="r-tz", + event_type="log", + category="log", + content="hello", + ) + events = await store.list_events("t-tz", "r-tz", user_id=None) + assert events, "expected at least one event" + _assert_tz_aware(events[0]["created_at"], context="run_event.list.created_at") + finally: + await _cleanup() From ca7042dec279d57139c17cc602e1fb674e4f469c Mon Sep 17 00:00:00 2001 From: john lee <64lamei@gmail.com> Date: Thu, 21 May 2026 16:42:26 +0800 Subject: [PATCH 59/86] chore(windows): add PYTHONIOENCODING and PYTHONUTF8 to backend Makefile targets (#3069) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit langgraph-api emits → and ⚠️ characters in version-check log lines. On Windows with cp1252 as the default stream encoding, each such line throws a UnicodeEncodeError inside the logging handler, littering startup output with tracebacks (though the server still boots). #1550 already fixed this for scripts/check.py via stream.reconfigure(). Apply the same treatment to the backend Makefile dev/gateway/test targets by setting PYTHONIOENCODING=utf-8 and PYTHONUTF8=1 before each uv run invocation. Both variables are no-ops on Linux/macOS where UTF-8 is already the default. Closes #2337 Co-authored-by: Claude Sonnet 4.6 --- backend/Makefile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/Makefile b/backend/Makefile index 81a055684..a1206547b 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -2,13 +2,13 @@ install: uv sync dev: - PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --reload + PYTHONPATH=. PYTHONIOENCODING=utf-8 PYTHONUTF8=1 uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --reload gateway: - PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 + PYTHONPATH=. PYTHONIOENCODING=utf-8 PYTHONUTF8=1 uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 test: - PYTHONPATH=. uv run pytest tests/ -v + PYTHONPATH=. PYTHONIOENCODING=utf-8 PYTHONUTF8=1 uv run pytest tests/ -v lint: uvx ruff check . From df951542827ff97f400494cec99739ac5ef99b87 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Thu, 21 May 2026 16:49:31 +0800 Subject: [PATCH 60/86] fix(tracing): propagate session_id and user_id into Langfuse traces (#2944) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(tracing): propagate session_id and user_id into Langfuse traces Adds Langfuse v4 reserved trace attributes (langfuse_session_id, langfuse_user_id, langfuse_trace_name, langfuse_tags) to RunnableConfig.metadata inside the run worker, so the langchain CallbackHandler can lift them onto the root trace. - New deerflow.tracing.metadata.build_langfuse_trace_metadata() returns the reserved keys when Langfuse is in the enabled providers, else {}. - worker.run_agent merges them with setdefault so caller-supplied keys win, allowing per-request overrides from upstream metadata. - session_id mirrors the LangGraph thread_id; user_id reads get_effective_user_id() (falls back to "default" in no-auth mode). - trace_name defaults to "lead-agent"; tags carry env and model name when DEER_FLOW_ENV (or ENVIRONMENT) and a model name are present. Closes #2930 * fix(tracing): attach Langfuse callback at graph root so metadata propagates The first commit injected ``langfuse_session_id`` / ``langfuse_user_id`` / ``langfuse_trace_name`` / ``langfuse_tags`` into ``RunnableConfig.metadata``, but on ``main`` the Langfuse callback is attached at *model* level (``models/factory.py``). LangChain still threads ``parent_run_id`` through the contextvar, so the handler sees the model as a nested observation and ``__on_llm_action`` strips the ``langfuse_*`` keys (``keep_langfuse_trace_attributes=False``). The trace's top-level ``sessionId`` / ``userId`` therefore stayed empty in deer-flow's LangGraph runtime — confirmed live against a real Langfuse instance. This commit moves the callback to the **graph invocation root** so the handler fires ``on_chain_start(parent_run_id=None)`` and runs the ``propagate_attributes`` path that actually lifts ``session_id`` / ``user_id`` onto the trace: - ``models/factory.py``: add ``attach_tracing`` keyword (default ``True``) so standalone callers (``MemoryUpdater``, etc.) keep their direct model-level tracing. - ``agents/lead_agent/agent.py``: call ``build_tracing_callbacks()`` once inside ``_make_lead_agent`` and append the result to ``config["callbacks"]``; the four in-graph ``create_chat_model`` sites (bootstrap, default agent, sync + async summarization) pass ``attach_tracing=False`` to avoid duplicate spans. - ``agents/middlewares/title_middleware.py``: same ``attach_tracing=False`` for the title-generation model, since it inherits the graph's RunnableConfig via ``_get_runnable_config``. Test updates: - ``tests/test_lead_agent_model_resolution.py`` and ``tests/test_title_middleware_core_logic.py``: extend the fake ``create_chat_model`` signatures / mock assertions to accept the new ``attach_tracing`` kwarg. - ``tests/test_worker_langfuse_metadata.py``: switch the no-user fallback test from direct ContextVar mutation to ``monkeypatch.setattr`` on ``get_effective_user_id`` to avoid pollution across the langfuse OTel global tracer provider. - ``tests/conftest.py``: add an autouse fixture that resets ``deerflow.config.title_config._title_config`` to its pristine default after every test. Any test that loads the real ``config.yaml`` (via ``get_app_config()``) calls ``load_title_config_from_dict`` and mutates the module-level singleton, which previously poisoned the title-middleware suite when run after, e.g., the new ``test_worker_langfuse_metadata.py`` cases. The fixture is independent of this PR's main change but unblocks the cross-file test run. Live verification (same Langfuse instance as before): - Drove ``worker.run_agent`` against the real ``make_lead_agent`` + ``gpt-4o-mini`` for three distinct ``user_context`` identities (``fancy-engineer``, ``alice-pm``, ``bob-designer``). - Each run produced one ``lead-agent`` trace whose top-level ``sessionId`` / ``userId`` / ``tags`` carry the expected values, e.g. ``session=e2e-2930-8f347c-alice-pm user=alice-pm name='lead-agent' tags=['model:gpt-4o-mini']``. Refs #2930. * fix(tracing): extend root-callback + metadata injection to the embedded client Addresses Copilot review on PR #2944. Commit 2 disabled model-level tracing for ``TitleMiddleware`` and ``_create_summarization_middleware`` because ``_make_lead_agent`` now attaches the tracing callbacks at the graph invocation root. But the embedded ``DeerFlowClient`` does not call ``_make_lead_agent`` — it calls ``_build_middlewares`` directly and never appends the tracing handlers to its ``RunnableConfig``. So under the embedded path, title-generation and summarization LLM calls were left untraced — a regression introduced by this PR. This commit mirrors the gateway worker's injection in ``DeerFlowClient.stream``: - Append ``build_tracing_callbacks()`` to ``config["callbacks"]`` so the Langfuse handler sees ``on_chain_start(parent_run_id=None)`` at the graph root and runs the ``propagate_attributes`` path. - Merge ``build_langfuse_trace_metadata(...)`` into ``config["metadata"]`` with ``setdefault`` so caller-supplied keys still win. - ``_ensure_agent`` now creates its main model with ``attach_tracing=False`` to avoid duplicate spans now that the callback lives at the graph root. Docs: - ``backend/CLAUDE.md`` Tracing section rewritten to describe the graph-root attachment model (replacing the inaccurate "at model-creation time" wording). - ``README.md`` Langfuse section now lists both injection points (worker + client) instead of only the worker path. Tests: - ``tests/test_client_langfuse_metadata.py`` (new, 3 cases): callbacks + metadata are injected when Langfuse is enabled, caller-supplied metadata overrides win via ``setdefault``, and the injection is inert when Langfuse is disabled. Live verification on the real Langfuse instance: === user=fancy-client === id=cbd22847.. session=client-2930-6b9491-fancy-client user=fancy-client name='lead-agent' === user=alice-client === id=b4f6f576.. session=client-2930-6b9491-alice-client user=alice-client name='lead-agent' Refs #2930. * refactor(tracing): address maintainer review on PR #2944 Addresses @WillemJiang's 5 comments. 1. Duplicated metadata-injection code between worker.py and client.py New ``deerflow.tracing.inject_langfuse_metadata(config, ...)`` helper takes the 10-line build + merge + setdefault logic that was duplicated in ``runtime/runs/worker.py`` and ``client.py``. Both callers now share a single source of truth, so the two paths cannot drift. 2. Direct private-attribute mutation in conftest.py and tests Added public ``reset_tracing_config()`` / ``reset_title_config()`` functions. ``tests/conftest.py`` and every test that previously did ``tracing_module._tracing_config = None`` or ``title_module._title_config = TitleConfig()`` now goes through the public API. A future internal rename will surface as an ImportError instead of a silent no-op. 3. client.py reading os.environ directly ``DeerFlowClient.__init__`` grows an optional ``environment`` parameter so programmatic callers can pass the deployment label explicitly. ``stream()`` consults ``self._environment`` first and only falls back to ``DEER_FLOW_ENV`` / ``ENVIRONMENT`` env vars when nothing was passed in. Backwards compatible — env-var behaviour preserved for callers that opt to keep using it. 4. build_tracing_callbacks() cached on hot path Not implemented. Inspected the langfuse v4 ``langchain.CallbackHandler`` constructor: it only resolves the module-level singleton client via ``get_client()`` and initialises a few dicts (no I/O, no env parsing at construction time). The build is essentially free. Caching would trade a non-measurable speedup for two real risks: handler instances carry per-run state internally (``_run_states``, ``_root_run_states``, ``last_trace_id``), and tracing config can be reloaded by env-var changes between runs. Will revisit if profiling ever shows it as a hot spot. 5. attach_tracing=False easy to forget at new in-graph call sites - Module docstring at the top of ``lead_agent/agent.py`` documents the invariant ("every in-graph ``create_chat_model`` MUST pass ``attach_tracing=False``") and enumerates the current sites. - New regression test ``test_make_lead_agent_attaches_tracing_callbacks_at_graph_root`` in ``tests/test_lead_agent_model_resolution.py`` locks both halves of the invariant: ``config["callbacks"]`` carries the tracing handler after ``_make_lead_agent``, AND every ``create_chat_model`` call captured by the test passes ``attach_tracing=False``. A future in-graph site that forgets the flag will fail this test. Lint clean. Full touched-suite bundle: 246 passed. --------- Co-authored-by: Willem Jiang --- README.md | 9 + backend/CLAUDE.md | 18 ++ .../deerflow/agents/lead_agent/agent.py | 46 +++- .../agents/middlewares/title_middleware.py | 6 +- backend/packages/harness/deerflow/client.py | 38 ++- .../harness/deerflow/config/title_config.py | 13 + .../harness/deerflow/config/tracing_config.py | 12 + .../harness/deerflow/models/factory.py | 26 +- .../harness/deerflow/runtime/runs/worker.py | 16 ++ .../harness/deerflow/tracing/__init__.py | 7 +- .../harness/deerflow/tracing/metadata.py | 105 ++++++++ backend/tests/conftest.py | 25 ++ .../tests/test_client_langfuse_metadata.py | 159 +++++++++++ .../tests/test_lead_agent_model_resolution.py | 55 +++- .../tests/test_title_middleware_core_logic.py | 3 +- backend/tests/test_tracing_config.py | 3 +- backend/tests/test_tracing_factory.py | 10 +- backend/tests/test_tracing_metadata.py | 137 ++++++++++ .../tests/test_worker_langfuse_metadata.py | 248 ++++++++++++++++++ 19 files changed, 910 insertions(+), 26 deletions(-) create mode 100644 backend/packages/harness/deerflow/tracing/metadata.py create mode 100644 backend/tests/test_client_langfuse_metadata.py create mode 100644 backend/tests/test_tracing_metadata.py create mode 100644 backend/tests/test_worker_langfuse_metadata.py diff --git a/README.md b/README.md index 8248e8fe4..83b43fd93 100644 --- a/README.md +++ b/README.md @@ -546,6 +546,15 @@ LANGFUSE_BASE_URL=https://cloud.langfuse.com If you are using a self-hosted Langfuse instance, set `LANGFUSE_BASE_URL` to your deployment URL. +**Trace correlation fields.** Every agent run is annotated with Langfuse's reserved trace attributes so the Sessions and Users pages light up automatically: + +- `session_id` = LangGraph `thread_id` — groups every trace of the same conversation +- `user_id` = effective user from `get_effective_user_id()` (falls back to `default` in no-auth mode) +- `trace_name` = assistant id (defaults to `lead-agent`) +- `tags` = `[env:, model:]` (omitted when not set) + +These are injected into `RunnableConfig.metadata` at the graph invocation root for both the gateway path (`runtime/runs/worker.py::run_agent`) and the embedded path (`client.py::DeerFlowClient.stream`), so any LangChain-compatible callback can read them. Set `DEER_FLOW_ENV` (or `ENVIRONMENT`) to tag traces by deployment environment. + #### Using Both Providers If both LangSmith and Langfuse are enabled, DeerFlow attaches both tracing callbacks and reports the same model activity to both systems. diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 886b82dcb..8c4711395 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -397,6 +397,24 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_ - `resolve_variable(path)` - Import module and return variable (e.g., `module.path:variable_name`) - `resolve_class(path, base_class)` - Import and validate class against base class +### Tracing System (`packages/harness/deerflow/tracing/`) + +LangSmith and Langfuse are both supported. The wiring lives in two layers: + +- `factory.py::build_tracing_callbacks()` — returns the LangChain `CallbackHandler` list for the providers currently enabled via env vars (`LANGSMITH_TRACING`, `LANGFUSE_TRACING`, etc.). The handlers are attached at the **graph invocation root** for in-graph runs (`make_lead_agent` and `DeerFlowClient.stream` both append them to `config["callbacks"]` before invoking the graph) so a single run produces one trace with all node / LLM / tool calls as child spans. Standalone callers — anything that invokes a model outside such a graph (e.g. `MemoryUpdater`) — keep `create_chat_model`'s default `attach_tracing=True`, which falls back to model-level callback attachment. +- `metadata.py::build_langfuse_trace_metadata()` — builds the Langfuse-reserved trace attributes for `RunnableConfig.metadata`. The Langfuse v4 `langchain.CallbackHandler` lifts these onto the root trace (see its `_parse_langfuse_trace_attributes`), but only when it sees `on_chain_start(parent_run_id=None)` — which is why the callbacks have to live at the graph root, not the model. + +**Trace-attribute injection points**: both `runtime/runs/worker.py::run_agent` (gateway path) and `client.py::DeerFlowClient.stream` (embedded path) merge the metadata into `config["metadata"]` right before constructing the graph. Caller-supplied keys win via `setdefault`, so an external `session_id` override is preserved. Field mapping: + +| Langfuse field | Source | +|-----------------------|----------------------------------------------| +| `langfuse_session_id` | LangGraph `thread_id` | +| `langfuse_user_id` | `get_effective_user_id()` (`default` in no-auth) | +| `langfuse_trace_name` | `RunRecord.assistant_id` / client `agent_name` (defaults to `lead-agent`) | +| `langfuse_tags` | `env:` + `model:` | + +Returns `{}` when Langfuse is not in the enabled providers — LangSmith-only deployments are unaffected. Set `DEER_FLOW_ENV` (or `ENVIRONMENT`) to tag traces by deployment environment. Tests live in `tests/test_tracing_factory.py`, `tests/test_tracing_metadata.py`, `tests/test_worker_langfuse_metadata.py`, and `tests/test_client_langfuse_metadata.py`. + ### Config Schema **`config.yaml`** key sections: diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index f4330abc1..328a8a6e1 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -1,3 +1,23 @@ +"""Lead agent factory. + +INVARIANT — tracing callback placement +====================================== + +Tracing callbacks (Langfuse, LangSmith) are attached at the **graph +invocation root** in :func:`_make_lead_agent` (see the +``build_tracing_callbacks()`` block that appends to ``config["callbacks"]``). +Every ``create_chat_model(...)`` call inside this module — and inside any +middleware reachable from this graph (e.g. ``TitleMiddleware``) — MUST pass +``attach_tracing=False``. + +Forgetting that flag emits duplicate spans (one rooted at the graph, one at +the model) AND prevents the Langfuse handler's ``propagate_attributes`` +path from firing, so ``session_id`` / ``user_id`` never reach the trace. +The four current sites are: bootstrap agent, default agent, summarization +middleware, and the async path inside ``TitleMiddleware``. Any new in-graph +``create_chat_model`` call must add to this list and pass the flag. +""" + import logging from langchain.agents import create_agent @@ -22,6 +42,7 @@ from deerflow.config.app_config import AppConfig, get_app_config from deerflow.models import create_chat_model from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools from deerflow.skills.types import Skill +from deerflow.tracing import build_tracing_callbacks logger = logging.getLogger(__name__) @@ -73,10 +94,14 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> # Bind "middleware:summarize" tag so RunJournal identifies these LLM calls # as middleware rather than lead_agent (SummarizationMiddleware is a # LangChain built-in, so we tag the model at creation time). + # attach_tracing=False because the graph-level RunnableConfig (set in + # ``_make_lead_agent``) already carries tracing callbacks; binding them + # again at the model level would emit duplicate spans and break + # ``session_id`` / ``user_id`` propagation. if config.model_name: - model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config) + model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config, attach_tracing=False) else: - model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config) + model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config, attach_tracing=False) model = model.with_config(tags=["middleware:summarize"]) # Prepare kwargs @@ -408,13 +433,26 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig): } ) + # Inject tracing callbacks at the graph invocation root so a single LangGraph + # run produces one trace with all node / LLM / tool calls as child spans, + # AND so the Langfuse handler sees ``on_chain_start(parent_run_id=None)`` and + # actually propagates ``langfuse_session_id`` / ``langfuse_user_id`` from + # ``config["metadata"]`` onto the trace. Without root-level attachment the + # model is a nested observation and the handler strips ``langfuse_*`` keys. + tracing_callbacks = build_tracing_callbacks() + if tracing_callbacks: + existing = config.get("callbacks") or [] + if not isinstance(existing, list): + existing = list(existing) + config["callbacks"] = [*existing, *tracing_callbacks] + skills_for_tool_policy = _load_enabled_skills_for_tool_policy(available_skills, app_config=resolved_app_config) if is_bootstrap: # Special bootstrap agent with minimal prompt for initial custom agent creation flow tools = get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent] return create_agent( - model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config), + model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config, attach_tracing=False), tools=filter_tools_by_skill_allowed_tools(tools, skills_for_tool_policy), middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config), system_prompt=apply_prompt_template( @@ -432,7 +470,7 @@ def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig): # Default lead agent (unchanged behavior) tools = get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled, app_config=resolved_app_config) return create_agent( - model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config), + model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config, attach_tracing=False), tools=filter_tools_by_skill_allowed_tools(tools + extra_tools, skills_for_tool_policy), middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config), system_prompt=apply_prompt_template( diff --git a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py index b259ce4a4..b6cc72b35 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py @@ -160,7 +160,11 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]): prompt, user_msg = self._build_title_prompt(state) try: - model_kwargs = {"thinking_enabled": False} + # attach_tracing=False because ``_get_runnable_config()`` inherits + # the graph-level RunnableConfig (set in ``_make_lead_agent``) whose + # callbacks already carry tracing handlers; binding them again at + # the model level would emit duplicate spans. + model_kwargs = {"thinking_enabled": False, "attach_tracing": False} if self._app_config is not None: model_kwargs["app_config"] = self._app_config if config.model_name: diff --git a/backend/packages/harness/deerflow/client.py b/backend/packages/harness/deerflow/client.py index 786e7372f..8ffa89e2c 100644 --- a/backend/packages/harness/deerflow/client.py +++ b/backend/packages/harness/deerflow/client.py @@ -19,6 +19,7 @@ import asyncio import json import logging import mimetypes +import os import shutil import tempfile import uuid @@ -42,6 +43,7 @@ from deerflow.config.paths import get_paths from deerflow.models import create_chat_model from deerflow.runtime.user_context import get_effective_user_id from deerflow.skills.storage import get_or_new_skill_storage +from deerflow.tracing import build_tracing_callbacks, inject_langfuse_metadata from deerflow.uploads.manager import ( claim_unique_filename, delete_file_safe, @@ -123,6 +125,7 @@ class DeerFlowClient: agent_name: str | None = None, available_skills: set[str] | None = None, middlewares: Sequence[AgentMiddleware] | None = None, + environment: str | None = None, ): """Initialize the client. @@ -140,6 +143,12 @@ class DeerFlowClient: agent_name: Name of the agent to use. available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available. middlewares: Optional list of custom middlewares to inject into the agent. + environment: Deployment environment label that ends up in + ``langfuse_tags`` (e.g. ``"production"`` / ``"staging"``). + When ``None`` the worker/client falls back to the + ``DEER_FLOW_ENV`` or ``ENVIRONMENT`` env vars. Pass an + explicit value for programmatic callers that do not want + env-var coupling. """ if config_path is not None: reload_app_config(config_path) @@ -156,6 +165,7 @@ class DeerFlowClient: self._agent_name = agent_name self._available_skills = set(available_skills) if available_skills is not None else None self._middlewares = list(middlewares) if middlewares else [] + self._environment = environment # Lazy agent — created on first call, recreated when config changes. self._agent = None @@ -228,7 +238,11 @@ class DeerFlowClient: max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3) kwargs: dict[str, Any] = { - "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled), + # attach_tracing=False because ``stream()`` injects tracing + # callbacks at the graph invocation root so a single embedded run + # produces one trace with correct session_id / user_id propagation. + # Attaching them again on the model would emit duplicate spans. + "model": create_chat_model(name=model_name, thinking_enabled=thinking_enabled, attach_tracing=False), "tools": self._get_tools(model_name=model_name, subagent_enabled=subagent_enabled), "middleware": _build_middlewares(config, model_name=model_name, agent_name=self._agent_name, custom_middlewares=self._middlewares), "system_prompt": apply_prompt_template( @@ -571,6 +585,28 @@ class DeerFlowClient: thread_id = str(uuid.uuid4()) config = self._get_runnable_config(thread_id, **kwargs) + + # Inject tracing callbacks and Langfuse trace metadata at the graph + # invocation root so the embedded client matches the gateway worker's + # behaviour: a single ``stream()`` produces one trace with all node / + # LLM / tool calls nested under it, and the trace carries the reserved + # ``langfuse_session_id`` / ``langfuse_user_id`` keys that the Langfuse + # CallbackHandler lifts onto the root trace's ``sessionId`` / ``userId``. + tracing_callbacks = build_tracing_callbacks() + if tracing_callbacks: + existing_callbacks = list(config.get("callbacks") or []) + config["callbacks"] = [*existing_callbacks, *tracing_callbacks] + + configurable = config.get("configurable") or {} + inject_langfuse_metadata( + config, + thread_id=thread_id, + user_id=get_effective_user_id(), + assistant_id=self._agent_name or "lead-agent", + model_name=configurable.get("model_name") or self._model_name, + environment=self._environment or os.environ.get("DEER_FLOW_ENV") or os.environ.get("ENVIRONMENT"), + ) + self._ensure_agent(config) state: dict[str, Any] = {"messages": [HumanMessage(content=message)]} diff --git a/backend/packages/harness/deerflow/config/title_config.py b/backend/packages/harness/deerflow/config/title_config.py index f335b4952..2d2e73789 100644 --- a/backend/packages/harness/deerflow/config/title_config.py +++ b/backend/packages/harness/deerflow/config/title_config.py @@ -51,3 +51,16 @@ def load_title_config_from_dict(config_dict: dict) -> None: """Load title configuration from a dictionary.""" global _title_config _title_config = TitleConfig(**config_dict) + + +def reset_title_config() -> None: + """Restore the title configuration to its pristine ``TitleConfig()`` default. + + Public API so that tests do not have to reach into the private + ``_title_config`` module attribute. ``AppConfig.from_file()`` calls + :func:`load_title_config_from_dict`, which permanently mutates the + singleton; tests that need a clean slate between cases should call + this between tests. + """ + global _title_config + _title_config = TitleConfig() diff --git a/backend/packages/harness/deerflow/config/tracing_config.py b/backend/packages/harness/deerflow/config/tracing_config.py index 1ef5ebeb4..399e37424 100644 --- a/backend/packages/harness/deerflow/config/tracing_config.py +++ b/backend/packages/harness/deerflow/config/tracing_config.py @@ -147,3 +147,15 @@ def validate_enabled_tracing_providers() -> None: def is_tracing_enabled() -> bool: """Check if any tracing provider is enabled and fully configured.""" return get_tracing_config().is_configured + + +def reset_tracing_config() -> None: + """Discard the cached :class:`TracingConfig` so the next call rebuilds it. + + Public API so that tests do not have to reach into the private + ``_tracing_config`` module attribute. A future internal rename would + silently break callers that mutate the attribute directly. + """ + global _tracing_config + with _config_lock: + _tracing_config = None diff --git a/backend/packages/harness/deerflow/models/factory.py b/backend/packages/harness/deerflow/models/factory.py index 518bdc9f1..c6a3573f8 100644 --- a/backend/packages/harness/deerflow/models/factory.py +++ b/backend/packages/harness/deerflow/models/factory.py @@ -47,11 +47,24 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con model_settings_from_config["stream_usage"] = True -def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, **kwargs) -> BaseChatModel: +def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, attach_tracing: bool = True, **kwargs) -> BaseChatModel: """Create a chat model instance from the config. Args: name: The name of the model to create. If None, the first model in the config will be used. + thinking_enabled: Enable the model's extended-thinking mode when supported. + app_config: Explicit application config; falls back to the cached global if omitted. + attach_tracing: When True (default), attach tracing callbacks (Langfuse, + LangSmith) directly to the model instance. Standalone callers — anything + that invokes the model outside a LangGraph run that already wires tracing + at the invocation root (``MemoryUpdater``, ad-hoc utilities, etc.) — keep + this default so the model-level callback still produces traces. Callers + that already attach tracing at the graph root (``make_lead_agent``, the + in-graph ``TitleMiddleware``) MUST pass ``attach_tracing=False``; otherwise + the same LLM call emits duplicate spans (one rooted at the graph, one at + the model) and ``session_id`` / ``user_id`` metadata never reach the trace + because the model becomes a nested observation whose ``langfuse_*`` keys + get stripped. Returns: A chat model instance. @@ -149,9 +162,10 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, * model_instance = model_class(**kwargs, **model_settings_from_config) - callbacks = build_tracing_callbacks() - if callbacks: - existing_callbacks = model_instance.callbacks or [] - model_instance.callbacks = [*existing_callbacks, *callbacks] - logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}") + if attach_tracing: + callbacks = build_tracing_callbacks() + if callbacks: + existing_callbacks = model_instance.callbacks or [] + model_instance.callbacks = [*existing_callbacks, *callbacks] + logger.debug(f"Tracing attached to model '{name}' with providers={len(callbacks)}") return model_instance diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 09e3c66e9..aa47cd39b 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -19,6 +19,7 @@ import asyncio import copy import inspect import logging +import os from dataclasses import dataclass, field from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, cast @@ -31,6 +32,8 @@ if TYPE_CHECKING: from deerflow.config.app_config import AppConfig from deerflow.runtime.serialization import serialize from deerflow.runtime.stream_bridge import StreamBridge +from deerflow.runtime.user_context import get_effective_user_id +from deerflow.tracing import inject_langfuse_metadata from .manager import RunManager, RunRecord from .naming import resolve_root_run_name @@ -225,6 +228,19 @@ async def run_agent( if journal is not None: config.setdefault("callbacks", []).append(journal) + # Inject Langfuse trace-attribute metadata so the langchain CallbackHandler + # can lift session_id / user_id / trace_name / tags onto the root trace. + # Shared helper with ``DeerFlowClient.stream`` so both entry points stay + # in sync; caller-provided metadata wins via setdefault inside the helper. + inject_langfuse_metadata( + config, + thread_id=thread_id, + user_id=get_effective_user_id(), + assistant_id=record.assistant_id, + model_name=record.model_name, + environment=os.environ.get("DEER_FLOW_ENV") or os.environ.get("ENVIRONMENT"), + ) + # Resolve after runtime context installation so context/configurable reflect # the agent name that this run will actually execute. config.setdefault("run_name", resolve_root_run_name(config, record.assistant_id)) diff --git a/backend/packages/harness/deerflow/tracing/__init__.py b/backend/packages/harness/deerflow/tracing/__init__.py index f132815fb..6d00e9c69 100644 --- a/backend/packages/harness/deerflow/tracing/__init__.py +++ b/backend/packages/harness/deerflow/tracing/__init__.py @@ -1,3 +1,8 @@ from .factory import build_tracing_callbacks +from .metadata import build_langfuse_trace_metadata, inject_langfuse_metadata -__all__ = ["build_tracing_callbacks"] +__all__ = [ + "build_langfuse_trace_metadata", + "build_tracing_callbacks", + "inject_langfuse_metadata", +] diff --git a/backend/packages/harness/deerflow/tracing/metadata.py b/backend/packages/harness/deerflow/tracing/metadata.py new file mode 100644 index 000000000..3dabf169a --- /dev/null +++ b/backend/packages/harness/deerflow/tracing/metadata.py @@ -0,0 +1,105 @@ +"""Langfuse trace-attribute metadata builders. + +The Langfuse v4 ``langchain.CallbackHandler`` lifts a fixed set of reserved +keys from ``RunnableConfig.metadata`` onto the root trace: + +- ``langfuse_session_id`` → groups traces (LangGraph thread → Langfuse Session) +- ``langfuse_user_id`` → trace user_id (powers the Users page) +- ``langfuse_trace_name`` → human-readable trace name +- ``langfuse_tags`` → trace tags + +See ``langfuse/langchain/CallbackHandler.py::_parse_langfuse_trace_attributes`` +and https://langfuse.com/docs/observability/features/sessions for the +contract. Builders here exist so the gateway/run worker can inject the +right metadata without leaking Langfuse internals into the call sites. +""" + +from __future__ import annotations + +from typing import Any + +from deerflow.config import get_enabled_tracing_providers + +# Lazy-imported below to avoid a circular import: ``deerflow.runtime`` eagerly +# imports the run worker, which in turn needs ``deerflow.tracing``. +_DEFAULT_TRACE_NAME = "lead-agent" + + +def build_langfuse_trace_metadata( + *, + thread_id: str | None, + user_id: str | None = None, + assistant_id: str | None = None, + model_name: str | None = None, + environment: str | None = None, +) -> dict[str, Any]: + """Return Langfuse trace-attribute metadata for ``RunnableConfig.metadata``. + + Returns ``{}`` when Langfuse is not in the enabled tracing providers so + callers can unconditionally merge the result without affecting LangSmith + or other tracers. + + Args: + thread_id: LangGraph thread id; mapped to ``langfuse_session_id``. + user_id: Effective user id; falls back to ``DEFAULT_USER_ID`` when + ``None`` so the Langfuse Users page works in no-auth mode. + assistant_id: Optional agent identifier; defaults to ``"lead-agent"``. + model_name: Model name; emitted as ``model:`` in ``langfuse_tags``. + environment: Deployment env (e.g. ``"production"``); emitted as + ``env:`` in ``langfuse_tags``. + """ + if "langfuse" not in get_enabled_tracing_providers(): + return {} + + from deerflow.runtime.user_context import DEFAULT_USER_ID + + metadata: dict[str, Any] = { + "langfuse_session_id": thread_id, + "langfuse_user_id": user_id or DEFAULT_USER_ID, + "langfuse_trace_name": assistant_id or _DEFAULT_TRACE_NAME, + } + + tags: list[str] = [] + if environment: + tags.append(f"env:{environment}") + if model_name: + tags.append(f"model:{model_name}") + if tags: + metadata["langfuse_tags"] = tags + + return metadata + + +def inject_langfuse_metadata( + config: dict, + *, + thread_id: str | None, + user_id: str | None = None, + assistant_id: str | None = None, + model_name: str | None = None, + environment: str | None = None, +) -> None: + """Merge Langfuse trace-attribute metadata into ``config["metadata"]``. + + Shared by the gateway worker (``runtime/runs/worker.py``) and the + embedded client (``client.py``) so the two paths cannot drift apart. + + Caller-supplied metadata wins via ``setdefault`` — an upstream value + for e.g. ``langfuse_session_id`` set by the frontend stays untouched. + The ``config`` dict is mutated in place; the call is a no-op when + Langfuse is not in the enabled tracing providers. + """ + langfuse_metadata = build_langfuse_trace_metadata( + thread_id=thread_id, + user_id=user_id, + assistant_id=assistant_id, + model_name=model_name, + environment=environment, + ) + if not langfuse_metadata: + return + + merged_metadata = dict(config.get("metadata") or {}) + for key, value in langfuse_metadata.items(): + merged_metadata.setdefault(key, value) + config["metadata"] = merged_metadata diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 9bc8d4884..5293652f6 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -176,6 +176,31 @@ def _reset_skill_storage_singleton(): reset_skill_storage() +@pytest.fixture(autouse=True) +def _restore_title_config_singleton(): + """Reset ``_title_config`` to its pristine default after every test. + + ``AppConfig.from_file()`` writes the on-disk ``title`` block into the + module-level singleton (``config/app_config.py`` calls + ``load_title_config_from_dict``). Any test that loads the real + ``config.yaml`` therefore leaves the singleton in a state that + ``test_title_middleware_core_logic.py`` does not expect; that suite + relies on the pristine ``TitleConfig()`` default (``enabled=True``). + We restore the default after every test so test files stay + independent regardless of order. + """ + try: + from deerflow.config.title_config import reset_title_config + except ImportError: + yield + return + + try: + yield + finally: + reset_title_config() + + @pytest.fixture(autouse=True) def _auto_user_context(request): """Inject a default ``test-user-autouse`` into the contextvar. diff --git a/backend/tests/test_client_langfuse_metadata.py b/backend/tests/test_client_langfuse_metadata.py new file mode 100644 index 000000000..3116fd331 --- /dev/null +++ b/backend/tests/test_client_langfuse_metadata.py @@ -0,0 +1,159 @@ +"""Tests for DeerFlowClient's graph-root tracing wiring. + +Regression coverage for the Copilot review on PR #2944: when the title +and summarization middlewares request ``attach_tracing=False`` we must +make sure ``DeerFlowClient`` injects the tracing callbacks at the graph +invocation root instead, otherwise those middlewares produce untraced +LLM calls. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest + +from deerflow.client import DeerFlowClient + + +class _FakeAgent: + """Capture the ``config`` handed to ``agent.stream``.""" + + def __init__(self) -> None: + self.captured_config: dict | None = None + self.checkpointer = None + self.store = None + + def stream(self, state, *, config, context, stream_mode): + self.captured_config = config + return iter(()) # empty stream + + +@pytest.fixture(autouse=True) +def _clear_langfuse_env(monkeypatch): + from deerflow.config.tracing_config import reset_tracing_config + + for name in ("LANGFUSE_TRACING", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY", "LANGFUSE_BASE_URL"): + monkeypatch.delenv(name, raising=False) + reset_tracing_config() + yield + reset_tracing_config() + + +def _stub_agent_creation(monkeypatch, fake_agent: _FakeAgent) -> dict[str, Any]: + """Short-circuit the heavy parts of ``_ensure_agent`` so we can drive + ``stream()`` against a fake graph without touching real models, tools + or middleware factories. + """ + captured: dict[str, Any] = {} + + def _stub_ensure_agent(self, config): + captured["config"] = config + self._agent = fake_agent + self._agent_config_key = ("stub",) + + monkeypatch.setattr(DeerFlowClient, "_ensure_agent", _stub_ensure_agent) + return captured + + +def _make_client(_monkeypatch) -> DeerFlowClient: + """Build a client without going through ``__init__`` so we never load + config.yaml or perform any other side-effectful startup work.""" + fake_app_config = SimpleNamespace(models=[SimpleNamespace(name="stub-model")]) + client = DeerFlowClient.__new__(DeerFlowClient) + client._app_config = fake_app_config + client._extensions_config = None + client._model_name = "stub-model" + client._thinking_enabled = False + client._plan_mode = False + client._subagent_enabled = False + client._agent_name = None + client._available_skills = None + client._middlewares = None + client._checkpointer = None + client._agent = None + client._agent_config_key = None + client._environment = None + return client + + +def test_stream_injects_langfuse_metadata_when_enabled(monkeypatch): + monkeypatch.setenv("LANGFUSE_TRACING", "true") + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test") + from deerflow.config.tracing_config import reset_tracing_config + + reset_tracing_config() + + class _SentinelHandler: + pass + + sentinel = _SentinelHandler() + monkeypatch.setattr("deerflow.client.build_tracing_callbacks", lambda: [sentinel]) + + fake_agent = _FakeAgent() + captured = _stub_agent_creation(monkeypatch, fake_agent) + client = _make_client(monkeypatch) + + list(client.stream("hi", thread_id="thread-client-1")) + + config = captured["config"] + metadata = config.get("metadata") or {} + assert metadata.get("langfuse_session_id") == "thread-client-1" + assert metadata.get("langfuse_trace_name") == "lead-agent" + # Default no-auth context falls back to ``"default"`` user. + assert metadata.get("langfuse_user_id") in {"default", "test-user-autouse"} + callbacks = config.get("callbacks") or [] + assert sentinel in callbacks + + +def test_stream_is_inert_when_langfuse_disabled(monkeypatch): + monkeypatch.setattr("deerflow.client.build_tracing_callbacks", lambda: []) + + fake_agent = _FakeAgent() + captured = _stub_agent_creation(monkeypatch, fake_agent) + client = _make_client(monkeypatch) + + list(client.stream("hi", thread_id="thread-client-2")) + + config = captured["config"] + assert "callbacks" not in config or not config["callbacks"] + metadata = config.get("metadata") or {} + assert "langfuse_session_id" not in metadata + assert "langfuse_user_id" not in metadata + + +def test_stream_preserves_caller_metadata_overrides(monkeypatch): + monkeypatch.setenv("LANGFUSE_TRACING", "true") + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test") + from deerflow.config.tracing_config import reset_tracing_config + + reset_tracing_config() + monkeypatch.setattr("deerflow.client.build_tracing_callbacks", lambda: []) + + fake_agent = _FakeAgent() + captured = _stub_agent_creation(monkeypatch, fake_agent) + client = _make_client(monkeypatch) + + # Drive stream with a pre-populated metadata so the worker-equivalent + # ``setdefault`` semantics are exercised. + original_get_config = DeerFlowClient._get_runnable_config + + def patched_get_runnable_config(self, thread_id, **overrides): + cfg = original_get_config(self, thread_id, **overrides) + cfg["metadata"] = { + "langfuse_session_id": "explicit-session-override", + "langfuse_user_id": "explicit-user", + } + return cfg + + monkeypatch.setattr(DeerFlowClient, "_get_runnable_config", patched_get_runnable_config) + list(client.stream("hi", thread_id="thread-client-3")) + + metadata = captured["config"].get("metadata") or {} + assert metadata["langfuse_session_id"] == "explicit-session-override" + assert metadata["langfuse_user_id"] == "explicit-user" + # ``trace_name`` was not supplied by caller so the worker still fills it. + assert metadata["langfuse_trace_name"] == "lead-agent" diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index 976730d44..7ac4b97e6 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -41,6 +41,49 @@ def test_make_lead_agent_signature_matches_langgraph_server_factory_abi(): assert list(inspect.signature(lead_agent_module.make_lead_agent).parameters) == ["config"] +def test_make_lead_agent_attaches_tracing_callbacks_at_graph_root(monkeypatch): + """Regression guard: tracing handlers must be appended to + ``config["callbacks"]`` (graph invocation root), and every in-graph + ``create_chat_model`` call must pass ``attach_tracing=False``. + + Catches future contributors who forget the flag when adding new + in-graph model creation, which would silently produce duplicate + spans and break Langfuse session/user propagation. + """ + app_config = _make_app_config([_make_model("safe-model", supports_thinking=False)]) + + import deerflow.tools as tools_module + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: app_config) + monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: []) + monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: []) + + sentinel_handler = object() + monkeypatch.setattr(lead_agent_module, "build_tracing_callbacks", lambda: [sentinel_handler]) + + seen_attach_tracing: list[bool] = [] + + def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True): + seen_attach_tracing.append(attach_tracing) + return object() + + monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) + monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs) + + config: dict = {"configurable": {"model_name": "safe-model"}} + lead_agent_module._make_lead_agent(config, app_config=app_config) + + # Handler must land on the graph invocation config so the Langfuse + # CallbackHandler fires ``on_chain_start(parent_run_id=None)`` and + # propagates ``session_id`` / ``user_id`` onto the trace. + assert sentinel_handler in (config.get("callbacks") or []), "build_tracing_callbacks output must be appended to config['callbacks']" + + # Every in-graph create_chat_model call must opt out of model-level + # tracing to avoid duplicate spans. + assert seen_attach_tracing, "_make_lead_agent did not call create_chat_model" + assert all(flag is False for flag in seen_attach_tracing), f"in-graph create_chat_model must pass attach_tracing=False; got {seen_attach_tracing}" + + def test_internal_make_lead_agent_uses_explicit_app_config(monkeypatch): app_config = _make_app_config([_make_model("explicit-model", supports_thinking=False)]) @@ -55,7 +98,7 @@ def test_internal_make_lead_agent_uses_explicit_app_config(monkeypatch): captured: dict[str, object] = {} - def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None): + def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True): captured["name"] = name captured["app_config"] = app_config return object() @@ -89,7 +132,7 @@ def test_make_lead_agent_uses_runtime_app_config_from_context_without_global_rea captured: dict[str, object] = {} - def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None): + def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True): captured["name"] = name captured["app_config"] = app_config return object() @@ -168,7 +211,7 @@ def test_make_lead_agent_disables_thinking_when_model_does_not_support_it(monkey captured: dict[str, object] = {} - def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None): + def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True): captured["name"] = name captured["thinking_enabled"] = thinking_enabled captured["reasoning_effort"] = reasoning_effort @@ -212,7 +255,7 @@ def test_make_lead_agent_reads_runtime_options_from_context(monkeypatch): captured: dict[str, object] = {} - def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None): + def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True): captured["name"] = name captured["thinking_enabled"] = thinking_enabled captured["reasoning_effort"] = reasoning_effort @@ -407,7 +450,7 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch fake_model = MagicMock() fake_model.with_config.return_value = fake_model - def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None): + def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True): captured["name"] = name captured["thinking_enabled"] = thinking_enabled captured["reasoning_effort"] = reasoning_effort @@ -441,7 +484,7 @@ def test_create_summarization_middleware_threads_resolved_app_config_to_model(mo fake_model = MagicMock() fake_model.with_config.return_value = fake_model - def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None): + def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None, attach_tracing=True): captured["app_config"] = app_config return fake_model diff --git a/backend/tests/test_title_middleware_core_logic.py b/backend/tests/test_title_middleware_core_logic.py index 3fdf4d3f9..ac10848e1 100644 --- a/backend/tests/test_title_middleware_core_logic.py +++ b/backend/tests/test_title_middleware_core_logic.py @@ -109,7 +109,7 @@ class TestTitleMiddlewareCoreLogic: title = result["title"] assert title == "短标题" - title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False) + title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False, attach_tracing=False) model.ainvoke.assert_awaited_once() assert model.ainvoke.await_args.kwargs["config"] == { "run_name": "title_agent", @@ -141,6 +141,7 @@ class TestTitleMiddlewareCoreLogic: title_middleware_module.create_chat_model.assert_called_once_with( name="title-model", thinking_enabled=False, + attach_tracing=False, app_config=app_config, ) diff --git a/backend/tests/test_tracing_config.py b/backend/tests/test_tracing_config.py index a13be516d..943401c97 100644 --- a/backend/tests/test_tracing_config.py +++ b/backend/tests/test_tracing_config.py @@ -5,10 +5,11 @@ from __future__ import annotations import pytest from deerflow.config import tracing_config as tracing_module +from deerflow.config.tracing_config import reset_tracing_config def _reset_tracing_cache() -> None: - tracing_module._tracing_config = None + reset_tracing_config() @pytest.fixture(autouse=True) diff --git a/backend/tests/test_tracing_factory.py b/backend/tests/test_tracing_factory.py index b3e77935f..723e42e80 100644 --- a/backend/tests/test_tracing_factory.py +++ b/backend/tests/test_tracing_factory.py @@ -12,7 +12,7 @@ from deerflow.tracing import factory as tracing_factory @pytest.fixture(autouse=True) def clear_tracing_env(monkeypatch): - from deerflow.config import tracing_config as tracing_module + from deerflow.config.tracing_config import reset_tracing_config for name in ( "LANGSMITH_TRACING", @@ -30,9 +30,9 @@ def clear_tracing_env(monkeypatch): "LANGFUSE_BASE_URL", ): monkeypatch.delenv(name, raising=False) - tracing_module._tracing_config = None + reset_tracing_config() yield - tracing_module._tracing_config = None + reset_tracing_config() def test_build_tracing_callbacks_returns_empty_list_when_disabled(monkeypatch): @@ -114,12 +114,12 @@ def test_build_tracing_callbacks_raises_when_enabled_provider_fails(monkeypatch) def test_build_tracing_callbacks_raises_for_explicitly_enabled_misconfigured_provider(monkeypatch): - from deerflow.config import tracing_config as tracing_module + from deerflow.config.tracing_config import reset_tracing_config monkeypatch.setenv("LANGFUSE_TRACING", "true") monkeypatch.delenv("LANGFUSE_PUBLIC_KEY", raising=False) monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test") - tracing_module._tracing_config = None + reset_tracing_config() with pytest.raises(ValueError, match="LANGFUSE_PUBLIC_KEY"): tracing_factory.build_tracing_callbacks() diff --git a/backend/tests/test_tracing_metadata.py b/backend/tests/test_tracing_metadata.py new file mode 100644 index 000000000..6c758e40d --- /dev/null +++ b/backend/tests/test_tracing_metadata.py @@ -0,0 +1,137 @@ +"""Tests for deerflow.tracing.metadata.build_langfuse_trace_metadata.""" + +from __future__ import annotations + +import pytest + +from deerflow.tracing import metadata as tracing_metadata + + +@pytest.fixture(autouse=True) +def _clear_tracing_env(monkeypatch): + from deerflow.config.tracing_config import reset_tracing_config + + for name in ( + "LANGFUSE_TRACING", + "LANGFUSE_PUBLIC_KEY", + "LANGFUSE_SECRET_KEY", + "LANGFUSE_BASE_URL", + "LANGSMITH_TRACING", + "LANGCHAIN_TRACING_V2", + "LANGCHAIN_TRACING", + "LANGSMITH_API_KEY", + "LANGCHAIN_API_KEY", + ): + monkeypatch.delenv(name, raising=False) + reset_tracing_config() + yield + reset_tracing_config() + + +def _enable_langfuse(monkeypatch): + monkeypatch.setenv("LANGFUSE_TRACING", "true") + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test") + + +def test_returns_empty_when_langfuse_disabled(monkeypatch): + # No env vars set → langfuse not in enabled providers. + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="t-1", + user_id="u-1", + assistant_id="lead-agent", + model_name="gpt-4o", + ) + assert result == {} + + +def test_session_id_maps_to_thread_id(monkeypatch): + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="thread-abc", + user_id="user-42", + ) + + assert result["langfuse_session_id"] == "thread-abc" + + +def test_user_id_falls_back_to_default(monkeypatch): + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="thread-abc", + user_id=None, + ) + + assert result["langfuse_user_id"] == "default" + + +def test_user_id_explicit_value_wins(monkeypatch): + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="thread-abc", + user_id="alice@example.com", + ) + + assert result["langfuse_user_id"] == "alice@example.com" + + +def test_trace_name_uses_assistant_id_when_provided(monkeypatch): + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="t", + assistant_id="custom-agent", + ) + + assert result["langfuse_trace_name"] == "custom-agent" + + +def test_trace_name_defaults_to_lead_agent(monkeypatch): + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="t", + assistant_id=None, + ) + + assert result["langfuse_trace_name"] == "lead-agent" + + +def test_tags_include_env_and_model(monkeypatch): + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="t", + environment="production", + model_name="gpt-4o", + ) + + assert result["langfuse_tags"] == ["env:production", "model:gpt-4o"] + + +def test_tags_omitted_when_no_tag_inputs(monkeypatch): + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id="t", + user_id="u", + ) + + assert "langfuse_tags" not in result + + +def test_thread_id_none_still_produces_metadata(monkeypatch): + # Stateless run paths may not have a thread_id — we still want + # user_id / trace_name to flow through so Users page works. + _enable_langfuse(monkeypatch) + + result = tracing_metadata.build_langfuse_trace_metadata( + thread_id=None, + user_id="u-1", + ) + + assert result["langfuse_session_id"] is None + assert result["langfuse_user_id"] == "u-1" diff --git a/backend/tests/test_worker_langfuse_metadata.py b/backend/tests/test_worker_langfuse_metadata.py new file mode 100644 index 000000000..7b7544771 --- /dev/null +++ b/backend/tests/test_worker_langfuse_metadata.py @@ -0,0 +1,248 @@ +"""Integration test: worker.run_agent injects Langfuse trace metadata. + +Verifies that the agent factory's resulting graph receives a +``RunnableConfig`` whose ``metadata`` carries the Langfuse reserved keys +(``langfuse_session_id`` / ``langfuse_user_id`` / ``langfuse_trace_name``). +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from deerflow.runtime.runs.manager import RunRecord +from deerflow.runtime.runs.schemas import DisconnectMode, RunStatus +from deerflow.runtime.runs.worker import RunContext, run_agent + + +class _FakeAgent: + """Minimal LangGraph-like graph that captures the runnable config.""" + + def __init__(self) -> None: + self.captured_config: dict | None = None + self.metadata: dict = {} + # Worker may assign these attributes; need them to exist. + self.checkpointer = None + self.store = None + self.interrupt_before_nodes: list[str] = [] + self.interrupt_after_nodes: list[str] = [] + + async def astream(self, graph_input, *, config, stream_mode, **kwargs): + self.captured_config = config + # Empty async generator — no chunks produced. + return + yield # pragma: no cover (makes this an async generator) + + +class _FakeRunManager: + async def set_status(self, *_args, **_kwargs) -> None: + return None + + async def update_model_name(self, *_args, **_kwargs) -> None: + return None + + async def update_run_completion(self, *_args, **_kwargs) -> None: + return None + + +class _FakeBridge: + def __init__(self) -> None: + self.events: list[tuple[str, object]] = [] + + async def publish(self, _run_id, event, payload) -> None: + self.events.append((event, payload)) + + async def publish_end(self, _run_id) -> None: + self.events.append(("end", None)) + + async def cleanup(self, _run_id, *, delay: int = 0) -> None: + return None + + +@pytest.fixture(autouse=True) +def _clear_tracing_env(monkeypatch): + from deerflow.config.tracing_config import reset_tracing_config + + for name in ("LANGFUSE_TRACING", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY", "LANGFUSE_BASE_URL"): + monkeypatch.delenv(name, raising=False) + reset_tracing_config() + yield + reset_tracing_config() + + +@pytest.mark.asyncio +async def test_run_agent_injects_langfuse_metadata(monkeypatch): + monkeypatch.setenv("LANGFUSE_TRACING", "true") + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test") + from deerflow.config.tracing_config import reset_tracing_config + + reset_tracing_config() + + fake_agent = _FakeAgent() + + def agent_factory(config): + return fake_agent + + record = RunRecord( + run_id="run-1", + thread_id="thread-xyz", + assistant_id="lead-agent", + status=RunStatus.pending, + on_disconnect=DisconnectMode.cancel, + model_name="gpt-4o", + ) + record.abort_event = asyncio.Event() + ctx = RunContext(checkpointer=None) + + await run_agent( + _FakeBridge(), + _FakeRunManager(), + record, + ctx=ctx, + agent_factory=agent_factory, + graph_input={"messages": []}, + config={"configurable": {"thread_id": "thread-xyz"}}, + ) + + assert fake_agent.captured_config is not None, "astream was not invoked" + metadata = fake_agent.captured_config.get("metadata") or {} + assert metadata.get("langfuse_session_id") == "thread-xyz" + # conftest.py autouse fixture injects ``test-user-autouse`` into the + # contextvar — the worker should read it via ``get_effective_user_id``. + user_id = metadata.get("langfuse_user_id") + assert user_id == "test-user-autouse", f"expected test-user-autouse, got {user_id}" + assert metadata.get("langfuse_trace_name") == "lead-agent" + tags = metadata.get("langfuse_tags") or [] + assert "model:gpt-4o" in tags + + +@pytest.mark.asyncio +async def test_run_agent_falls_back_to_default_user_when_unset(monkeypatch): + """When no user is in the contextvar, langfuse_user_id falls back to 'default'. + + Uses ``monkeypatch.setattr`` to redirect ``get_effective_user_id`` to return + ``"default"`` rather than directly mutating the contextvar — direct contextvar + operations across pytest test boundaries have produced spooky cross-file + pollution when combined with the langfuse OTel global tracer provider. + """ + monkeypatch.setenv("LANGFUSE_TRACING", "true") + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test") + from deerflow.config.tracing_config import reset_tracing_config + from deerflow.runtime.runs import worker as worker_module + from deerflow.runtime.user_context import DEFAULT_USER_ID + + reset_tracing_config() + monkeypatch.setattr(worker_module, "get_effective_user_id", lambda: DEFAULT_USER_ID) + + fake_agent = _FakeAgent() + + def agent_factory(config): + return fake_agent + + record = RunRecord( + run_id="run-fallback", + thread_id="thread-fb", + assistant_id="lead-agent", + status=RunStatus.pending, + on_disconnect=DisconnectMode.cancel, + ) + record.abort_event = asyncio.Event() + ctx = RunContext(checkpointer=None) + + await run_agent( + _FakeBridge(), + _FakeRunManager(), + record, + ctx=ctx, + agent_factory=agent_factory, + graph_input={"messages": []}, + config={"configurable": {"thread_id": "thread-fb"}}, + ) + + metadata = fake_agent.captured_config.get("metadata") or {} + assert metadata.get("langfuse_user_id") == "default" + + +@pytest.mark.asyncio +async def test_run_agent_preserves_caller_metadata_overrides(monkeypatch): + """Caller-provided langfuse_* keys must NOT be overridden by the default injection.""" + monkeypatch.setenv("LANGFUSE_TRACING", "true") + monkeypatch.setenv("LANGFUSE_PUBLIC_KEY", "pk-lf-test") + monkeypatch.setenv("LANGFUSE_SECRET_KEY", "sk-lf-test") + from deerflow.config.tracing_config import reset_tracing_config + + reset_tracing_config() + + fake_agent = _FakeAgent() + + def agent_factory(config): + return fake_agent + + record = RunRecord( + run_id="run-2", + thread_id="thread-default", + assistant_id="lead-agent", + status=RunStatus.pending, + on_disconnect=DisconnectMode.cancel, + ) + record.abort_event = asyncio.Event() + ctx = RunContext(checkpointer=None) + + await run_agent( + _FakeBridge(), + _FakeRunManager(), + record, + ctx=ctx, + agent_factory=agent_factory, + graph_input={"messages": []}, + config={ + "configurable": {"thread_id": "thread-default"}, + "metadata": { + "langfuse_session_id": "custom-session-id", + "langfuse_user_id": "explicit-user", + }, + }, + ) + + metadata = fake_agent.captured_config.get("metadata") or {} + # Caller-supplied keys win. + assert metadata["langfuse_session_id"] == "custom-session-id" + assert metadata["langfuse_user_id"] == "explicit-user" + # Worker still fills in keys that the caller didn't set. + assert metadata["langfuse_trace_name"] == "lead-agent" + + +@pytest.mark.asyncio +async def test_run_agent_skips_metadata_when_langfuse_disabled(monkeypatch): + fake_agent = _FakeAgent() + + def agent_factory(config): + return fake_agent + + record = RunRecord( + run_id="run-3", + thread_id="thread-noop", + assistant_id="lead-agent", + status=RunStatus.pending, + on_disconnect=DisconnectMode.cancel, + ) + record.abort_event = asyncio.Event() + ctx = RunContext(checkpointer=None) + + await run_agent( + _FakeBridge(), + _FakeRunManager(), + record, + ctx=ctx, + agent_factory=agent_factory, + graph_input={"messages": []}, + config={"configurable": {"thread_id": "thread-noop"}}, + ) + + metadata = fake_agent.captured_config.get("metadata") or {} + assert "langfuse_session_id" not in metadata + assert "langfuse_user_id" not in metadata + assert "langfuse_trace_name" not in metadata From 1c5c585741f0cd6ed84a50eec9eb868229c9c4b0 Mon Sep 17 00:00:00 2001 From: Lawrance_YXLiao <32213920+kibabsquirrel@users.noreply.github.com> Date: Thu, 21 May 2026 20:35:46 +0800 Subject: [PATCH 61/86] fix(runtime): bound write_file execution-failure observations (#3133) * fix(runtime): bound write_file execution-failure observations * fix(runtime): preserve write_file error prefixes * test(runtime): trim write_file prefix assertions * refactor(runtime): drop redundant exception suffix for permission/directory write errors Address Copilot review on #3133: the PermissionError and IsADirectoryError branches now return self-contained, non-redundant messages (e.g. "Error: Permission denied writing to file: /mnt/...") via direct truncation, instead of going through _format_write_file_error which appended a duplicate ": PermissionError: permission denied" suffix. OSError, SandboxError and the generic Exception branches keep the unified "Failed to write file '{path}': {ExceptionType}: {detail}" format so the model still sees a stable, machine-readable error class. Removes the now-unused message= parameter from _format_write_file_error, keeping a single code path. Truncation contract (<= 2000 chars) and host-path sanitization unchanged. * fix(runtime): handle write_file sandbox init errors Initialize the requested path before sandbox setup so early sandbox failures can still return a bounded write_file error. Add a regression test for sandbox initialization failures. * style(test): format sandbox security tests --- .../harness/deerflow/sandbox/tools.py | 55 +++++- backend/tests/test_sandbox_tools_security.py | 165 ++++++++++++++++++ 2 files changed, 214 insertions(+), 6 deletions(-) diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index c8c0b06fb..6edc88882 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -42,6 +42,7 @@ _DEFAULT_GLOB_MAX_RESULTS = 200 _MAX_GLOB_MAX_RESULTS = 1000 _DEFAULT_GREP_MAX_RESULTS = 100 _MAX_GREP_MAX_RESULTS = 500 +_DEFAULT_WRITE_FILE_ERROR_MAX_CHARS = 2000 _LOCAL_BASH_CWD_COMMANDS = {"cd", "pushd"} _LOCAL_BASH_COMMAND_WRAPPERS = {"command", "builtin"} _LOCAL_BASH_COMMAND_PREFIX_KEYWORDS = {"!", "{", "case", "do", "elif", "else", "for", "if", "select", "then", "time", "until", "while"} @@ -435,6 +436,42 @@ def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str: return msg +def _truncate_write_file_error_detail(detail: str, max_chars: int) -> str: + """Middle-truncate write_file error details, preserving the head and tail.""" + if max_chars == 0: + return detail + if len(detail) <= max_chars: + return detail + total = len(detail) + marker_max_len = len(f"\n... [write_file error truncated: {total} chars skipped] ...\n") + kept = max(0, max_chars - marker_max_len) + if kept == 0: + return detail[:max_chars] + head_len = kept // 2 + tail_len = kept - head_len + skipped = total - kept + marker = f"\n... [write_file error truncated: {skipped} chars skipped] ...\n" + return f"{detail[:head_len]}{marker}{detail[-tail_len:] if tail_len > 0 else ''}" + + +def _format_write_file_error( + requested_path: str, + error: Exception, + runtime: Runtime | None = None, + *, + max_chars: int = _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS, +) -> str: + """Return a bounded, sanitized error string for write_file failures.""" + header = f"Error: Failed to write file '{requested_path}'" + detail = _sanitize_error(error, runtime) + if max_chars == 0: + return f"{header}: {detail}" + detail_budget = max_chars - len(header) - 2 + if detail_budget <= 0: + return _truncate_write_file_error_detail(f"{header}: {detail}", max_chars) + return f"{header}: {_truncate_write_file_error_detail(detail, detail_budget)}" + + def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str: """Replace virtual /mnt/user-data paths with actual thread data paths. @@ -1651,9 +1688,9 @@ def write_file_tool( append: Whether to append content to the end of the file instead of overwriting it. Defaults to false. """ try: + requested_path = path sandbox = ensure_sandbox_initialized(runtime) ensure_thread_directories_exist(runtime) - requested_path = path if is_local_sandbox(runtime): thread_data = get_thread_data(runtime) validate_local_tool_path(path, thread_data) @@ -1664,15 +1701,21 @@ def write_file_tool( sandbox.write_file(path, content, append) return "OK" except SandboxError as e: - return f"Error: {e}" + return _format_write_file_error(requested_path, e, runtime) except PermissionError: - return f"Error: Permission denied writing to file: {requested_path}" + return _truncate_write_file_error_detail( + f"Error: Permission denied writing to file: {requested_path}", + _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS, + ) except IsADirectoryError: - return f"Error: Path is a directory, not a file: {requested_path}" + return _truncate_write_file_error_detail( + f"Error: Path is a directory, not a file: {requested_path}", + _DEFAULT_WRITE_FILE_ERROR_MAX_CHARS, + ) except OSError as e: - return f"Error: Failed to write file '{requested_path}': {_sanitize_error(e, runtime)}" + return _format_write_file_error(requested_path, e, runtime) except Exception as e: - return f"Error: Unexpected error writing file: {_sanitize_error(e, runtime)}" + return _format_write_file_error(requested_path, e, runtime) async def _write_file_tool_async( diff --git a/backend/tests/test_sandbox_tools_security.py b/backend/tests/test_sandbox_tools_security.py index 57466a0fe..d43a1fcf0 100644 --- a/backend/tests/test_sandbox_tools_security.py +++ b/backend/tests/test_sandbox_tools_security.py @@ -5,6 +5,7 @@ from unittest.mock import patch import pytest +from deerflow.sandbox.exceptions import SandboxError from deerflow.sandbox.tools import ( VIRTUAL_PATH_PREFIX, _apply_cwd_prefix, @@ -1140,6 +1141,170 @@ def test_str_replace_and_append_on_same_path_should_preserve_both_updates(monkey assert sandbox.content == "ALPHA\ntail\n" +def test_write_file_tool_bounds_large_oserror_and_masks_local_paths(monkeypatch) -> None: + class FailingSandbox: + id = "sandbox-write-large-oserror" + + def write_file(self, path: str, content: str, append: bool = False) -> None: + host_path = f"{_THREAD_DATA['workspace_path']}/nested/output.txt" + raise OSError(f"write failed at {host_path}\n{'A' * 12000}\nremote tail marker") + + runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}) + sandbox = FailingSandbox() + + monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox) + monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None) + monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: True) + monkeypatch.setattr("deerflow.sandbox.tools.get_thread_data", lambda runtime: _THREAD_DATA) + monkeypatch.setattr("deerflow.sandbox.tools.validate_local_tool_path", lambda path, thread_data: None) + monkeypatch.setattr( + "deerflow.sandbox.tools._resolve_and_validate_user_data_path", + lambda path, thread_data: f"{_THREAD_DATA['workspace_path']}/output.txt", + ) + + result = write_file_tool.func( + runtime=runtime, + description="写入大文件失败", + path="/mnt/user-data/workspace/output.txt", + content="report body", + ) + + assert len(result) <= 2000 + assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result + assert "/tmp/deer-flow/threads/t1/user-data/workspace" not in result + assert "/mnt/user-data/workspace/nested/output.txt" in result + assert "remote tail marker" in result + assert "[write_file error truncated:" in result + + +def test_write_file_tool_preserves_short_oserror_without_truncation(monkeypatch) -> None: + class FailingSandbox: + id = "sandbox-write-short-oserror" + + def write_file(self, path: str, content: str, append: bool = False) -> None: + raise OSError("disk quota exceeded") + + runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}) + sandbox = FailingSandbox() + + monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox) + monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None) + monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False) + + result = write_file_tool.func( + runtime=runtime, + description="写入失败", + path="/mnt/user-data/workspace/output.txt", + content="tiny payload", + ) + + assert result == "Error: Failed to write file '/mnt/user-data/workspace/output.txt': OSError: disk quota exceeded" + assert "[write_file error truncated:" not in result + + +def test_write_file_tool_bounds_large_sandbox_error(monkeypatch) -> None: + class FailingSandbox: + id = "sandbox-write-large-sandbox-error" + + def write_file(self, path: str, content: str, append: bool = False) -> None: + raise SandboxError(f"remote write rejected {'B' * 12000} final detail") + + runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}) + sandbox = FailingSandbox() + + monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox) + monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None) + monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False) + + result = write_file_tool.func( + runtime=runtime, + description="远端写入失败", + path="/mnt/user-data/workspace/output.txt", + content="tiny payload", + ) + + assert len(result) <= 2000 + assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result + assert "SandboxError: remote write rejected" in result + assert "final detail" in result + assert "[write_file error truncated:" in result + + +@pytest.mark.parametrize( + ("raised_error", "expected_fragment"), + [ + pytest.param( + PermissionError("permission denied"), + "Error: Permission denied writing to file: /mnt/user-data/workspace/output.txt", + id="permission", + ), + pytest.param( + IsADirectoryError("target is a directory"), + "Error: Path is a directory, not a file: /mnt/user-data/workspace/output.txt", + id="directory", + ), + pytest.param( + Exception("remote sandbox timeout"), + "Exception: remote sandbox timeout", + id="generic", + ), + ], +) +def test_write_file_tool_formats_all_other_failure_branches( + monkeypatch, + raised_error: Exception, + expected_fragment: str, +) -> None: + class FailingSandbox: + id = "sandbox-write-other-failure" + + def write_file(self, path: str, content: str, append: bool = False) -> None: + raise raised_error + + runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}) + sandbox = FailingSandbox() + + monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: sandbox) + monkeypatch.setattr("deerflow.sandbox.tools.ensure_thread_directories_exist", lambda runtime: None) + monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False) + + result = write_file_tool.func( + runtime=runtime, + description="验证错误分支格式化", + path="/mnt/user-data/workspace/output.txt", + content="tiny payload", + ) + + assert "/mnt/user-data/workspace/output.txt" in result + assert expected_fragment in result + assert "[write_file error truncated:" not in result + + +def test_write_file_tool_handles_sandbox_init_failure(monkeypatch) -> None: + """Regression for #3133 review: SandboxError raised during sandbox + initialization (before the local `requested_path` assignment) must still + surface as a bounded tool error rather than an UnboundLocalError. + """ + + def raise_sandbox_error(runtime): + raise SandboxError("sandbox missing") + + runtime = SimpleNamespace(state={}, context={"thread_id": "thread-1"}, config={}) + monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", raise_sandbox_error) + monkeypatch.setattr("deerflow.sandbox.tools.is_local_sandbox", lambda runtime: False) + + result = write_file_tool.func( + runtime=runtime, + description="sandbox 初始化失败", + path="/mnt/user-data/workspace/output.txt", + content="tiny payload", + ) + + assert "Error: Failed to write file '/mnt/user-data/workspace/output.txt':" in result + assert "SandboxError: sandbox missing" in result + assert "[write_file error truncated:" not in result + + def test_file_operation_lock_memory_cleanup() -> None: """Verify that released locks are eventually cleaned up by WeakValueDictionary. From 9c03a71a07e4365be6799c70270bd3cc6a90ef8e Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Thu, 21 May 2026 21:06:19 +0800 Subject: [PATCH 62/86] fix(gateway): preserve message additional_kwargs in normalize_input (#3132) (#3136) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(gateway): preserve message additional_kwargs in normalize_input (#3132) The gateway's hand-rolled dict→message coercion only forwarded `content` and collapsed every role to `HumanMessage`, silently dropping the frontend's `additional_kwargs.files` payload (along with `id`, `name`, and ai/system/tool roles). Effect on issue #3132: - `UploadsMiddleware` saw no `files` on the last human message, so the just-uploaded file got bucketed under "previous messages" while the current turn was reported as `(empty)`. - The persisted human message had no `files`, so the attachment chip on the message disappeared the moment the optimistic UI cleared. Delegate the conversion to `langchain_core.messages.utils.convert_to_messages` so `additional_kwargs`, `id`, `name`, and non-human roles round-trip unchanged. * fix(gateway): convert malformed-message ValueError into HTTP 400 normalize_input now sits at the request boundary, so a malformed input.messages[N] dict (missing role/type/content, unsupported role, etc.) should surface as 400 with the offending index — not bubble out of FastAPI as 500. Per Copilot review on #3136. --- backend/app/gateway/services.py | 39 ++++++++---- backend/tests/test_gateway_services.py | 88 ++++++++++++++++++++++++++ 2 files changed, 115 insertions(+), 12 deletions(-) diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 4713d303e..95e26144a 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -15,7 +15,8 @@ from collections.abc import Mapping from typing import Any from fastapi import HTTPException, Request -from langchain_core.messages import HumanMessage +from langchain_core.messages import BaseMessage +from langchain_core.messages.utils import convert_to_messages from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge from app.gateway.utils import sanitize_log_param @@ -76,21 +77,35 @@ def normalize_stream_modes(raw: list[str] | str | None) -> list[str]: def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]: - """Convert LangGraph Platform input format to LangChain state dict.""" + """Convert LangGraph Platform input format to LangChain state dict. + + Delegates dict→message coercion to ``langchain_core.messages.utils.convert_to_messages`` + so that ``additional_kwargs`` (e.g. uploaded-file metadata — gh #3132), ``id``, + ``name``, and non-human roles (ai/system/tool) survive unchanged. An earlier + hand-rolled version only forwarded ``content`` and collapsed every role to + ``HumanMessage``, which silently stripped frontend-supplied attachments. + + Malformed message dicts (missing ``role``/``type``/``content``, unsupported + role, etc.) raise ``HTTPException(400)`` with the offending index, instead + of bubbling up as a 500. The gateway is a system boundary, so per-entry + validation errors are the right shape for clients to retry against. + """ if raw_input is None: return {} messages = raw_input.get("messages") if messages and isinstance(messages, list): - converted = [] - for msg in messages: - if isinstance(msg, dict): - role = msg.get("role", msg.get("type", "user")) - content = msg.get("content", "") - if role in ("user", "human"): - converted.append(HumanMessage(content=content)) - else: - # TODO: handle other message types (system, ai, tool) - converted.append(HumanMessage(content=content)) + converted: list[Any] = [] + for index, msg in enumerate(messages): + if isinstance(msg, BaseMessage): + converted.append(msg) + elif isinstance(msg, dict): + try: + converted.extend(convert_to_messages([msg])) + except (ValueError, TypeError, NotImplementedError) as exc: + raise HTTPException( + status_code=400, + detail=f"Invalid message at input.messages[{index}]: {exc}", + ) from exc else: converted.append(msg) return {**raw_input, "messages": converted} diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py index aa9e20e78..2ccd372bf 100644 --- a/backend/tests/test_gateway_services.py +++ b/backend/tests/test_gateway_services.py @@ -81,6 +81,94 @@ def test_normalize_input_passthrough(): assert result == {"custom_key": "value"} +def test_normalize_input_preserves_additional_kwargs_and_id(): + """Regression: gh #3132 — frontend ships uploaded-file metadata in + additional_kwargs.files (and a client-side message id). The gateway must + not strip them before the graph runs, otherwise UploadsMiddleware reports + "(empty)" for new uploads and the frontend message loses its file chip. + """ + from langchain_core.messages import HumanMessage + + from app.gateway.services import normalize_input + + files = [{"filename": "a.csv", "size": 100, "path": "/mnt/user-data/uploads/a.csv", "status": "uploaded"}] + result = normalize_input( + { + "messages": [ + { + "type": "human", + "id": "client-msg-1", + "name": "user-input", + "content": [{"type": "text", "text": "clean it"}], + "additional_kwargs": {"files": files, "custom": "keep-me"}, + } + ] + } + ) + assert len(result["messages"]) == 1 + msg = result["messages"][0] + assert isinstance(msg, HumanMessage) + assert msg.id == "client-msg-1" + assert msg.name == "user-input" + assert msg.content == [{"type": "text", "text": "clean it"}] + assert msg.additional_kwargs == {"files": files, "custom": "keep-me"} + + +def test_normalize_input_passes_through_basemessage_instances(): + from langchain_core.messages import HumanMessage + + from app.gateway.services import normalize_input + + msg = HumanMessage(content="hello", id="m-1", additional_kwargs={"files": [{"filename": "x"}]}) + result = normalize_input({"messages": [msg]}) + assert result["messages"][0] is msg + + +def test_normalize_input_rejects_malformed_message_with_400(): + """Boundary validation: ``convert_to_messages`` raises ``ValueError`` when a + message dict is missing ``role``/``type``/``content``. ``normalize_input`` + runs inside the gateway HTTP boundary, so a malformed payload should surface + as a 400 referencing the offending entry — not bubble up as a 500. + + Raised after the Copilot review on PR #3136. + """ + import pytest + from fastapi import HTTPException + + from app.gateway.services import normalize_input + + with pytest.raises(HTTPException) as excinfo: + normalize_input({"messages": [{"role": "human", "content": "ok"}, {"oops": "no role here"}]}) + assert excinfo.value.status_code == 400 + assert "input.messages[1]" in excinfo.value.detail + + +def test_normalize_input_handles_non_human_roles(): + """The previous implementation collapsed every role to HumanMessage with a + `# TODO: handle other message types` comment. Resuming a thread with prior + AI/tool messages would silently rewrite them as human turns — corrupting + the conversation. Use langchain's standard conversion so ai/system/tool + roles round-trip correctly. + """ + from langchain_core.messages import AIMessage, SystemMessage, ToolMessage + + from app.gateway.services import normalize_input + + result = normalize_input( + { + "messages": [ + {"role": "system", "content": "sys"}, + {"role": "ai", "content": "hi", "id": "ai-1"}, + {"role": "tool", "content": "result", "tool_call_id": "call-1"}, + ] + } + ) + types = [type(m) for m in result["messages"]] + assert types == [SystemMessage, AIMessage, ToolMessage] + assert result["messages"][1].id == "ai-1" + assert result["messages"][2].tool_call_id == "call-1" + + def test_build_run_config_basic(): from app.gateway.services import build_run_config From 4cb2a22400c4087bd098c2f46dcb8557888f7eb5 Mon Sep 17 00:00:00 2001 From: john lee <64lamei@gmail.com> Date: Thu, 21 May 2026 21:13:24 +0800 Subject: [PATCH 63/86] =?UTF-8?q?docs(config.example):=20fix=20Claude=20th?= =?UTF-8?q?inking=20example=20=E2=80=94=20add=20supports=5Fthinking=20and?= =?UTF-8?q?=20budget=5Ftokens=20(#3068)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The commented Claude example used Claude 3.5 Sonnet with when_thinking_enabled but lacked supports_thinking: true. Copying the block and swapping to a Claude 4 model name would silently fall back to non-thinking mode (agent.py line 380 suppresses the error and logs only a warning). A second trap: budget_tokens is required by the Anthropic API when thinking.type == "enabled"; there is no server default. The old example omitted it, so any user who did add supports_thinking: true would get an API error on the first thinking request. Replace with a Claude Sonnet 4 example that includes both fields and inline comments explaining the constraints. Closes #2336 Co-authored-by: Claude Sonnet 4.6 --- config.example.yaml | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 7396f6cfb..9ea4e4c08 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -118,19 +118,25 @@ models: # For Docker deployments, use host.docker.internal instead of localhost: # base_url: http://host.docker.internal:11434 - # Example: Anthropic Claude model - # - name: claude-3-5-sonnet - # display_name: Claude 3.5 Sonnet + # Example: Anthropic Claude model (with extended thinking) + # supports_thinking: true is required — without it, DeerFlow silently falls + # back to non-thinking mode even when the UI thinking toggle is on. + # budget_tokens is required by the Anthropic API when thinking.type=enabled + # (no server default; min 1024; must be less than max_tokens). + # - name: claude-sonnet-4 + # display_name: Claude Sonnet 4 # use: langchain_anthropic:ChatAnthropic - # model: claude-3-5-sonnet-20241022 + # model: claude-sonnet-4-20250514 # api_key: $ANTHROPIC_API_KEY # default_request_timeout: 600.0 # max_retries: 2 - # max_tokens: 8192 - # supports_vision: true # Enable vision support for view_image tool + # max_tokens: 16000 + # supports_vision: true + # supports_thinking: true # when_thinking_enabled: # thinking: # type: enabled + # budget_tokens: 4096 # required; min 1024; must be < max_tokens # when_thinking_disabled: # thinking: # type: disabled From e93f65847207f8cbcf03587218e8d4ebaa51abcf Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Thu, 21 May 2026 21:18:10 +0800 Subject: [PATCH 64/86] fix(stability): resolve P0 blockers from v2.0-m1-rc1 stability audit (#3107) (#3131) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(task-tool): unwrap callback manager when locating usage recorder `config["callbacks"]` may arrive as a `BaseCallbackManager` (e.g. the `AsyncCallbackManager` LangChain hands to async tool runs), not just a plain list. The previous `for cb in callbacks` loop raised `TypeError: 'AsyncCallbackManager' object is not iterable`, which `ToolErrorHandlingMiddleware` then converted into a failed `task` ToolMessage even though the subagent had completed internally — Ultra mode lost subagent results and the lead agent fell back to redoing the work. Unwrap `BaseCallbackManager.handlers` before searching for the recorder. Refs: bytedance/deer-flow#3107 (BUG-002) * fix(frontend): treat any task tool error as a terminal subtask failure The subtask card status machine matched only three English prefixes (`Task Succeeded. Result:`, `Task failed.`, `Task timed out`). Anything else fell through to `in_progress`, so a `task` tool error wrapped by `ToolErrorHandlingMiddleware` (`Error: Tool 'task' failed ...`) left the card spinning forever even after the run had ended. Extract the prefix logic into `parseSubtaskResult` and recognise any leading `Error:` token as a terminal failure. The extracted function is unit-tested against the legacy prefixes plus the `AsyncCallbackManager` regression captured in the upstream issue. Refs: bytedance/deer-flow#3107 (BUG-007) * fix(frontend): exclude hidden, reasoning, and tool payloads from chat export `formatThreadAsMarkdown` / `formatThreadAsJSON` iterated raw messages without running the UI-level `isHiddenFromUIMessage` filter. Exported transcripts therefore included `hide_from_ui` system reminders, memory injections, provider `reasoning_content`, tool calls, and tool result messages — content that is intentionally hidden in the chat view. Filter the export to the user-visible transcript by default and gate reasoning / tool calls / tool messages / hidden messages behind explicit `ExportOptions` flags so a future debug export can opt back in without forking the formatter. Refs: bytedance/deer-flow#3107 (BUG-006) * fix(gateway): route get_config through get_app_config for mtime hot reload `get_config(request)` returned the `app.state.config` snapshot captured at startup. The worker / lead-agent path then threaded that frozen `AppConfig` through `RunContext` and `agent_factory`, so per-run fields edited in `config.yaml` (notably `max_tokens`) were ignored until the gateway process was restarted — even though `get_app_config()` already does mtime-based reload at the bottom layer. Route the request dependency through `get_app_config()` directly. Runtime `ContextVar` overrides (`push_current_app_config`) and test-injected singletons (`set_app_config`) keep working; `app.state.config` is now only read at startup for one-shot bootstrap (logging level, IM channels, `langgraph_runtime` engines). `tests/test_gateway_deps_config.py` encoded the old snapshot contract and is removed; `tests/test_gateway_config_freshness.py` replaces it with mtime, ContextVar, and `set_app_config` coverage. `test_skills_custom_router.py` and `test_uploads_router.py` now inject test configs via FastAPI `dependency_overrides[get_config]` instead of mutating `app.state.config`. Document the hot-reload boundary in `backend/CLAUDE.md` so reviewers know which fields are picked up on the next request vs. which still require a restart (`database`, `checkpointer`, `run_events`, `stream_bridge`, `sandbox.use`, `log_level`, `channels.*`). Refs: bytedance/deer-flow#3107 (BUG-001) * fix(gateway): broaden get_config 503 to any config-load failure Address review feedback on the previous commit: 1. Narrow exception catch removed. The old contract returned 503 whenever `app.state.config is None`. The first cut only mapped `FileNotFoundError`, leaving `PermissionError`, YAML parse errors, and pydantic `ValidationError` to bubble up as 500. At the request boundary we treat any inability to materialise the config as "configuration not available" (503) and log the original exception so the operator still has the stack. 2. Removed the unused `request: Request` parameter and the matching `# noqa: ARG001`. FastAPI's `Depends()` does not require the dependency to accept `Request`; the only call site uses the no-arg form. 3. `backend/CLAUDE.md` boundary now lists the *reason* each field is restart-required (engine binding, singleton caching, one-shot `apply_logging_level`, etc.), not just the field name, so reviewers do not have to reverse-engineer the boundary themselves. Tests parametrise four exception classes (`FileNotFoundError`, `PermissionError`, `ValueError`, `RuntimeError`) and assert 503 for each. Refs: bytedance/deer-flow#3107 (BUG-001) * fix(task-tool): defend _find_usage_recorder against non-list callbacks Address review feedback. The previous commit handled the two common shapes LangChain hands to async tool runs — a plain `list[BaseCallbackHandler]` and a `BaseCallbackManager` subclass — but iterated any other shape directly, which would still raise `TypeError` if e.g. a single handler instance leaked through without a list wrapper. Treat any non-list, non-manager `config["callbacks"]` value as "no recorder" rather than crash. Docstring now lists all four shapes explicitly. New tests cover the single-handler-object case, `runtime is None`, `callbacks is None`, and `runtime.config` being a non-dict — all required to be silent no-ops. Refs: bytedance/deer-flow#3107 (BUG-002) * fix(frontend): drop dead identity ternary and add opt-in export tests Address review feedback on the previous export commit: 1. Removed the no-op `typeof msg.content === "string" ? msg.content : msg.content` expression in `formatThreadAsJSON`. Both branches returned the same value; the message content now flows through unchanged whether it is a string or the rich `MessageContent[]` shape (LangChain JSON-serialises the array structure correctly already). 2. Expanded the JSDoc on `ExportOptions` to make it clearer that the four flags are not currently wired to any UI control — callers wanting a debug export must build the options object explicitly. The default behaviour continues to match the explicit prescription in bytedance/deer-flow#3107 BUG-006. 3. Added opt-in coverage. The previous tests only exercised the `options = {}` default path; the new cases verify each flag flips the corresponding payload back into the export so a future debug-export surface does not silently break the contract. Refs: bytedance/deer-flow#3107 (BUG-006) * fix(frontend): export subtask prefix constants and document fallback intent Address review feedback on the previous BUG-007 commit: 1. `SUCCESS_PREFIX`, `FAILURE_PREFIX`, `TIMEOUT_PREFIX`, and the `ERROR_WRAPPER_PATTERN` regex are now exported. The JSDoc explicitly pins them as part of the backend↔frontend contract defined in `task_tool.py` and `tool_error_handling_middleware.py`, so any future structured-status migration (e.g. backend writing `additional_kwargs.subagent_status` instead of leading text) can reference these from one canonical place rather than redefine them. 2. The `in_progress` fallback now carries a docstring explaining the deliberate choice — LangChain only ever emits a `ToolMessage` once the tool itself has returned, so unrecognised content means the contract has drifted and "still running" is the right operator signal (eagerly marking it terminal-failed would mask the drift). No behaviour change; this is documentation and an API export. Refs: bytedance/deer-flow#3107 (BUG-007) * fix(gateway): drop app.state.config snapshot and freeze run_events_config Address @ShenAC-SAC's BUG-001 review on #3131. The previous cut still stored an ``AppConfig`` snapshot on ``app.state.config`` for startup bootstrap. Two follow-on hazards from that: 1. Future code touching the gateway lifespan could accidentally start reading ``app.state.config`` again, silently regressing the request hot path back to a stale snapshot. 2. ``get_run_context()`` paired a freshly-reloaded ``AppConfig`` with the startup-bound ``event_store`` and a *live* ``run_events_config`` field — so an operator who edited ``run_events.backend`` mid-flight would have produced a run context whose ``event_store`` and ``run_events_config`` referred to different backends. Clean approach (aligned with the direction in PR #3128): - ``lifespan()`` keeps a local ``startup_config`` variable and passes it explicitly into ``langgraph_runtime(app, startup_config)`` and into ``start_channel_service``. No ``app.state.config`` attribute is set at any point. - ``langgraph_runtime`` now accepts ``startup_config`` as a required parameter, removing the ``getattr(app.state, "config", None)`` lookup and the "config not initialised" runtime error. - The matching ``run_events_config`` is frozen onto ``app.state`` next to ``run_event_store`` so ``get_run_context`` reads the two from the same startup-time source. ``app_config`` continues to be resolved live via ``get_app_config()``. - ``backend/CLAUDE.md`` boundary explanation updated to spell out the ``startup_config`` / ``get_app_config()`` split. New regression test ``test_run_context_app_config_reflects_yaml_edit`` exercises the worker-feeding path: it asserts that ``ctx.app_config`` follows a mid-flight ``config.yaml`` edit while ``ctx.run_events_config`` stays frozen to the startup snapshot the event store was built from. Refs: bytedance/deer-flow#3107 (BUG-001), bytedance/deer-flow#3131 review * fix(frontend): parse Task cancelled and polling timed out as terminal Address @ShenAC-SAC's BUG-007 review on #3131. `task_tool.py` actually emits five terminal strings: - `Task Succeeded. Result: …` - `Task failed. …` - `Task timed out. …` - `Task cancelled by user.` ← previously matched none - `Task polling timed out after N minutes …` ← previously matched none The previous cut handled three; the last two fell through to the "unknown content" branch and pushed the subtask card back to `in_progress` even though the backend had already reached a terminal state. Add explicit matches plus regression tests for both. The `in_progress` fallback is now reserved for genuinely unrecognised output (i.e. contract drift), as documented. Refs: bytedance/deer-flow#3107 (BUG-007), bytedance/deer-flow#3131 review * fix(frontend): sanitize JSON export content via the Markdown content path Address @ShenAC-SAC's BUG-006 review and the Copilot inline comment on #3131. The previous cut filtered hidden/tool messages out of the JSON export but still serialised `msg.content` verbatim, so: - inline `` wrappers stayed in the exported `content` even with `includeReasoning: false`, - content-array thinking blocks leaked the `thinking` field, - `` markers leaked the workspace paths a user uploaded files to. JSON now goes through the same sanitiser the Markdown path uses (`extractContentFromMessage` + `stripUploadedFilesTag`). Reasoning and tool_calls remain gated behind their `ExportOptions` flags. AI / human rows that sanitise to empty content with no opted-in reasoning or tool calls are dropped so the JSON matches the Markdown path's `continue` on empty assistant fragments. New regression tests cover the three leak shapes the reviewer called out plus the empty-content-drop case. Refs: bytedance/deer-flow#3107 (BUG-006), bytedance/deer-flow#3131 review * test(gateway): align lifespan stub with langgraph_runtime two-arg signature Codex round-3 review of c0bc7a06 flagged this: changing `langgraph_runtime` to require `startup_config` as a second positional argument broke the one-arg stub `_noop_langgraph_runtime(_app)` in `test_gateway_lifespan_shutdown.py`, which is patched into `app.gateway.app.langgraph_runtime` by the lifespan shutdown bounded-timeout regression. Lifespan would then call the stub with two args and raise `TypeError` before the bounded-shutdown assertion ran. Update the stub to match the new signature. The shutdown test itself is unaffected — it only cares about the channel `stop_channel_service` hang path. Refs: bytedance/deer-flow#3107 (BUG-001), bytedance/deer-flow#3131 review * fix(frontend): strip every known backend marker in export, not just uploads Codex round-3 review of 258ca800 and the matching maintainer feedback on PR #3131 made the same point: the JSON export now ran the Markdown-side sanitiser, but that sanitiser only stripped ``. The full set of payloads middleware embeds inside message `content` is larger: - `` — `UploadsMiddleware` - `` — `DynamicContextMiddleware` - `` — `DynamicContextMiddleware` (nested inside system-reminder) - `` — `DynamicContextMiddleware` The primary protection is still `isHiddenFromUIMessage`: the `` HumanMessage is marked `hide_from_ui: true` and never reaches the formatter. This commit adds the second line of defence so a regression that drops the `hide_from_ui` flag — or any future middleware that injects the same tag vocabulary into a visible HumanMessage — cannot leak the payload into the export file. Concrete changes: - New `INTERNAL_MARKER_TAGS` constant + `stripInternalMarkers(content)` helper in `core/messages/utils.ts`. The constant doubles as documentation for the backend↔frontend contract. - `formatMessageContent` in `export.ts` now calls `stripInternalMarkers` instead of `stripUploadedFilesTag`. UI render paths (`message-list-item.tsx`) keep using the narrower function so a user legitimately typing `` in a meta-discussion is preserved. - The "drop empty rows" guard in `buildJSONMessage` switched from `=== undefined` to truthy `!` checks. Codex spotted the asymmetry: when `extractReasoningContentFromMessage` returned the empty string (which it legitimately can), the JSON path emitted `{reasoning: ""}` while the Markdown path's `!reasoning` `continue` correctly dropped the row. New regression tests cover the defence-in-depth strip with a `` payload deliberately *not* marked `hide_from_ui`; tool-message sanitization under `includeToolMessages: true`; the mixed-content-array case (`thinking + text + image_url`); and the opted-in empty-reasoning drop. Live verification on a real Ultra-mode thread that uploaded a PDF (`曾鑫民-薪资交易流水.pdf`): backend state's first HumanMessage carries the `` block (with `/mnt/user-data/uploads/...` paths) as part of a content-array. The Markdown and JSON export blobs both come back free of ``, ``, ``, `tool_calls`, and reasoning — while preserving the user's `这是什么 ?` prompt and the assistant's visible answer. Refs: bytedance/deer-flow#3107 (BUG-006), bytedance/deer-flow#3131 review * test(frontend): cover trim, varied N, and pre-execution Error: prefixes Codex round-3 review of 50e2c257 flagged three coverage gaps in the subtask-status parser: 1. `Task cancelled by user.` and `Task polling timed out` previously had no whitespace-trim coverage — the original trim test only exercised the success prefix. Streaming chunks can arrive with leading/trailing newlines; the regex needed an explicit assertion. 2. The polling-timeout case was tested only at one `N` (15 minutes). The backend interpolates the live `timeout_seconds // 60` value, so the matcher must hold for any positive integer. Now we run the case for 1, 5, and 60 minutes. 3. `task_tool.py` also emits three `Error:` strings for pre-execution failures — unknown subagent type, host-bash disabled, and "task disappeared from background tasks". They are intentionally handled by `ERROR_WRAPPER_PATTERN` rather than dedicated prefixes (the wrapper already produces the right terminal-failed shape) but had no test coverage proving that wiring. Codex was right that a refactor splitting one of them off into its own prefix would silently break things. The JSDoc on the constants block now spells the three pre-execution errors out so the relationship between `task_tool.py` returns and the prefix vocabulary is explicit. No production code change beyond the docstring — this commit is pure coverage hardening for the contract that already exists. Refs: bytedance/deer-flow#3107 (BUG-007), bytedance/deer-flow#3131 review --- backend/CLAUDE.md | 12 + backend/app/gateway/app.py | 16 +- backend/app/gateway/deps.py | 86 ++++- .../deerflow/tools/builtins/task_tool.py | 21 +- .../tests/test_gateway_config_freshness.py | 189 +++++++++++ backend/tests/test_gateway_deps_config.py | 41 --- .../tests/test_gateway_lifespan_shutdown.py | 2 +- backend/tests/test_skills_custom_router.py | 4 +- .../tests/test_task_tool_usage_recorder.py | 91 +++++ backend/tests/test_uploads_router.py | 2 + .../workspace/messages/message-list.tsx | 32 +- frontend/src/core/messages/utils.ts | 44 +++ frontend/src/core/tasks/subtask-result.ts | 88 +++++ frontend/src/core/threads/export.ts | 110 +++++- .../unit/core/tasks/subtask-result.test.ts | 112 +++++++ .../tests/unit/core/threads/export.test.ts | 317 ++++++++++++++++++ 16 files changed, 1060 insertions(+), 107 deletions(-) create mode 100644 backend/tests/test_gateway_config_freshness.py delete mode 100644 backend/tests/test_gateway_deps_config.py create mode 100644 backend/tests/test_task_tool_usage_recorder.py create mode 100644 frontend/src/core/tasks/subtask-result.ts create mode 100644 frontend/tests/unit/core/tasks/subtask-result.test.ts create mode 100644 frontend/tests/unit/core/threads/export.test.ts diff --git a/backend/CLAUDE.md b/backend/CLAUDE.md index 8c4711395..f04774050 100644 --- a/backend/CLAUDE.md +++ b/backend/CLAUDE.md @@ -184,6 +184,18 @@ Setup: Copy `config.example.yaml` to `config.yaml` in the **project root** direc **Config Caching**: `get_app_config()` caches the parsed config, but automatically reloads it when the resolved config path changes or the file's mtime increases. This keeps Gateway and LangGraph reads aligned with `config.yaml` edits without requiring a manual process restart. +**Config Hot-Reload Boundary**: Gateway dependencies route through `get_app_config()` on every request, so per-run fields like `models[*].max_tokens`, `summarization.*`, `title.*`, `memory.*`, `subagents.*`, `tools[*]`, and the agent system prompt pick up `config.yaml` edits on the next message. `AppConfig` is intentionally **not** cached on `app.state` — `lifespan()` keeps a local `startup_config` variable for one-shot bootstrap work (logging level, channels, `langgraph_runtime` engines) and passes it explicitly to `langgraph_runtime(app, startup_config)`. Infrastructure fields are **restart-required**: + +| Field | Why a restart is required | +|---|---| +| `database.*` | `init_engine_from_config()` runs once during `langgraph_runtime()` startup; the SQLAlchemy engine holds the connection pool. | +| `checkpointer.*` (including SQLite WAL/journal settings) | `make_checkpointer()` binds the persistent checkpointer once at startup. | +| `run_events.*` | `make_run_event_store()` selects memory- vs. SQL-backed implementation at startup. | +| `stream_bridge.*` | `make_stream_bridge()` constructs the bridge object once. | +| `sandbox.use` | `get_sandbox_provider()` caches the provider singleton (`_default_sandbox_provider`); a new class path takes effect only on next process start. | +| `log_level` | `apply_logging_level()` is called only in `app.py` startup; it mutates the root logger's level, and `get_app_config()` returning a fresh `AppConfig` does not retrigger it. | +| `channels.*` IM platform credentials | `start_channel_service()` is invoked once during startup; live channels are not rebuilt on config change. | + Configuration priority: 1. Explicit `config_path` argument 2. `DEER_FLOW_CONFIG_PATH` environment variable diff --git a/backend/app/gateway/app.py b/backend/app/gateway/app.py index 2c13f571c..8baecb363 100644 --- a/backend/app/gateway/app.py +++ b/backend/app/gateway/app.py @@ -161,10 +161,16 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int: async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Application lifespan handler.""" - # Load config and check necessary environment variables at startup + # Load config and check necessary environment variables at startup. + # `startup_config` is a local snapshot used only for one-shot bootstrap + # work (logging level, langgraph_runtime engines, channels). Request-time + # config resolution always routes through `get_app_config()` in + # `app/gateway/deps.py::get_config()` so `config.yaml` edits become + # visible without a process restart. We deliberately do NOT cache this + # snapshot on `app.state` to keep that contract enforceable. try: - app.state.config = get_app_config() - apply_logging_level(app.state.config.log_level) + startup_config = get_app_config() + apply_logging_level(startup_config.log_level) logger.info("Configuration loaded successfully") except Exception as e: error_msg = f"Failed to load configuration during gateway startup: {e}" @@ -174,7 +180,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: logger.info(f"Starting API Gateway on {config.host}:{config.port}") # Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store) - async with langgraph_runtime(app): + async with langgraph_runtime(app, startup_config): logger.info("LangGraph runtime initialised") # Check admin bootstrap state and migrate orphan threads after admin exists. @@ -185,7 +191,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: try: from app.channels.service import start_channel_service - channel_service = await start_channel_service(app.state.config) + channel_service = await start_channel_service(startup_config) logger.info("Channel service started: %s", channel_service.get_status()) except Exception: logger.exception("No IM channels configured or channel service failed to start") diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 96ea7c5ea..f045a2ee3 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -3,11 +3,21 @@ **Getters** (used by routers): raise 503 when a required dependency is missing, except ``get_store`` which returns ``None``. +``AppConfig`` is intentionally *not* cached on ``app.state``. Routers and the +run path resolve it through :func:`deerflow.config.app_config.get_app_config`, +which performs mtime-based hot reload, so edits to ``config.yaml`` take +effect on the next request without a process restart. The engines created in +:func:`langgraph_runtime` (stream bridge, persistence, checkpointer, store, +run-event store) accept a ``startup_config`` snapshot — they are +restart-required by design and stay bound to that snapshot to keep the live +process consistent with itself. + Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. """ from __future__ import annotations +import logging from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager from typing import TYPE_CHECKING, TypeVar, cast @@ -15,12 +25,14 @@ from typing import TYPE_CHECKING, TypeVar, cast from fastapi import FastAPI, HTTPException, Request from langgraph.types import Checkpointer -from deerflow.config.app_config import AppConfig +from deerflow.config.app_config import AppConfig, get_app_config from deerflow.persistence.feedback import FeedbackRepository from deerflow.runtime import RunContext, RunManager, StreamBridge from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.runs.store.base import RunStore +logger = logging.getLogger(__name__) + if TYPE_CHECKING: from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.repositories.sqlite import SQLiteUserRepository @@ -30,21 +42,55 @@ if TYPE_CHECKING: T = TypeVar("T") -def get_config(request: Request) -> AppConfig: - """Return the app-scoped ``AppConfig`` stored on ``app.state``.""" - config = getattr(request.app.state, "config", None) - if config is None: - raise HTTPException(status_code=503, detail="Configuration not available") - return config +def get_config() -> AppConfig: + """Return the freshest ``AppConfig`` for the current request. + + Routes through :func:`deerflow.config.app_config.get_app_config`, which + honours runtime ``ContextVar`` overrides and reloads ``config.yaml`` from + disk when its mtime changes. ``AppConfig`` is not cached on ``app.state`` + at all — the only startup-time snapshot lives as a local + ``startup_config`` variable inside ``lifespan()`` and is passed + explicitly into :func:`langgraph_runtime` for the engines that are + restart-required by design. Routing every request through + :func:`get_app_config` closes the bytedance/deer-flow issue #3107 BUG-001 + split-brain where the worker / lead-agent thread saw a stale startup + snapshot. + + Any failure to materialise the config (missing file, permission denied, + YAML parse error, validation error) is reported as 503 — semantically + "the gateway cannot serve requests without a usable configuration" — and + logged with the original exception so operators have something to debug. + """ + try: + return get_app_config() + except Exception as exc: # noqa: BLE001 - request boundary: log and degrade gracefully + logger.exception("Failed to load AppConfig at request time") + raise HTTPException(status_code=503, detail="Configuration not available") from exc @asynccontextmanager -async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: +async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGenerator[None, None]: """Bootstrap and tear down all LangGraph runtime singletons. + ``startup_config`` is the ``AppConfig`` snapshot taken once during + ``lifespan()`` for one-shot infrastructure bootstrap. The engines and + stores constructed here (stream bridge, persistence engine, checkpointer, + store, run-event store) are restart-required by design — they hold live + connections, file handles, or singleton providers — so they bind to this + snapshot and survive across `config.yaml` edits. Request-time consumers + must still go through :func:`get_config` for any field that should be + hot-reloadable. See ``backend/CLAUDE.md`` "Config Hot-Reload Boundary". + + The matching ``run_events_config`` is frozen onto ``app.state`` so + :func:`get_run_context` pairs a freshly-loaded ``AppConfig`` with the + *startup-time* run-events configuration the underlying ``event_store`` + was built from — otherwise the runtime could end up combining a live + new ``run_events_config`` with an event store still bound to the + previous backend. + Usage in ``app.py``:: - async with langgraph_runtime(app): + async with langgraph_runtime(app, startup_config): yield """ from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config @@ -53,9 +99,7 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: from deerflow.runtime.events.store import make_run_event_store async with AsyncExitStack() as stack: - config = getattr(app.state, "config", None) - if config is None: - raise RuntimeError("langgraph_runtime() requires app.state.config to be initialized") + config = startup_config app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config)) @@ -84,8 +128,12 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: app.state.thread_store = make_thread_store(sf, app.state.store) - # Run event store (has its own factory with config-driven backend selection) + # Run event store. The store and the matching ``run_events_config`` are + # both frozen at startup so ``get_run_context`` does not combine a + # freshly-reloaded ``AppConfig.run_events`` with a store still bound to + # the previous backend. run_events_config = getattr(config, "run_events", None) + app.state.run_events_config = run_events_config app.state.run_event_store = make_run_event_store(run_events_config) # RunManager with store backing for persistence @@ -139,16 +187,20 @@ def get_thread_store(request: Request) -> ThreadMetaStore: def get_run_context(request: Request) -> RunContext: """Build a :class:`RunContext` from ``app.state`` singletons. - Returns a *base* context with infrastructure dependencies. + Returns a *base* context with infrastructure dependencies. The + ``app_config`` field is resolved live so per-run fields (e.g. + ``models[*].max_tokens``) follow ``config.yaml`` edits; the + ``event_store`` / ``run_events_config`` pair stays frozen to the snapshot + captured in :func:`langgraph_runtime` so callers never see a store bound + to one backend paired with a config pointing at another. """ - config = get_config(request) return RunContext( checkpointer=get_checkpointer(request), store=get_store(request), event_store=get_run_event_store(request), - run_events_config=getattr(config, "run_events", None), + run_events_config=getattr(request.app.state, "run_events_config", None), thread_store=get_thread_store(request), - app_config=config, + app_config=get_config(), ) diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index a45bff787..dab1377c6 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -7,6 +7,7 @@ from dataclasses import replace from typing import TYPE_CHECKING, Annotated, Any, cast from langchain.tools import InjectedToolCallId, tool +from langchain_core.callbacks import BaseCallbackManager from langgraph.config import get_stream_writer from deerflow.config import get_app_config @@ -99,15 +100,31 @@ def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls: def _find_usage_recorder(runtime: Any) -> Any | None: - """Find a callback handler with ``record_external_llm_usage_records`` in the runtime config.""" + """Find a callback handler with ``record_external_llm_usage_records`` in the runtime config. + + LangChain may pass ``config["callbacks"]`` in three different shapes: + + - ``None`` (no callbacks registered): no recorder. + - A plain ``list[BaseCallbackHandler]``: iterate it directly. + - A ``BaseCallbackManager`` instance (e.g. ``AsyncCallbackManager`` on async + tool runs): managers are not iterable, so we unwrap ``.handlers`` first. + + Any other shape (e.g. a single handler object accidentally passed without a + list wrapper) cannot be iterated safely; treat it as "no recorder" rather + than raise. + """ if runtime is None: return None config = getattr(runtime, "config", None) if not isinstance(config, dict): return None - callbacks = config.get("callbacks", []) + callbacks = config.get("callbacks") + if isinstance(callbacks, BaseCallbackManager): + callbacks = callbacks.handlers if not callbacks: return None + if not isinstance(callbacks, list): + return None for cb in callbacks: if hasattr(cb, "record_external_llm_usage_records"): return cb diff --git a/backend/tests/test_gateway_config_freshness.py b/backend/tests/test_gateway_config_freshness.py new file mode 100644 index 000000000..8f38ab6cc --- /dev/null +++ b/backend/tests/test_gateway_config_freshness.py @@ -0,0 +1,189 @@ +"""Regression tests for gateway config freshness on the request hot path. + +Bytedance/deer-flow issue #3107 BUG-001: the worker and lead-agent path +captured ``app.state.config`` at gateway startup. ``config.yaml`` edits during +runtime were therefore ignored — ``get_app_config()``'s mtime-based reload +existed but was bypassed because the snapshot object was passed through +explicitly. + +These tests pin the desired behaviour: a request-time ``get_config`` call must +observe the most recent on-disk ``config.yaml`` (mtime reload), and the +runtime ``ContextVar`` override must keep working for per-request injection. +""" + +from __future__ import annotations + +import os +from pathlib import Path + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient + +from app.gateway import deps as gateway_deps +from app.gateway.deps import get_config +from deerflow.config.app_config import ( + AppConfig, + pop_current_app_config, + push_current_app_config, + reset_app_config, + set_app_config, +) +from deerflow.config.sandbox_config import SandboxConfig + + +@pytest.fixture(autouse=True) +def _isolate_app_config_singleton(): + """Ensure each test starts with a clean module-level cache.""" + reset_app_config() + yield + reset_app_config() + + +def _write_config_yaml(path: Path, *, log_level: str) -> None: + path.write_text( + f""" +sandbox: + use: deerflow.sandbox.local.provider:LocalSandboxProvider +log_level: {log_level} +""".strip() + + "\n", + encoding="utf-8", + ) + + +def _build_app() -> FastAPI: + app = FastAPI() + + @app.get("/probe") + def probe(cfg: AppConfig = Depends(get_config)): + return {"log_level": cfg.log_level} + + return app + + +def test_get_config_reflects_file_mtime_reload(tmp_path, monkeypatch): + """Editing config.yaml at runtime must be visible to /probe without restart. + + This is the literal repro for the issue: the gateway must not freeze the + config to whatever was on disk when the process started. + """ + config_file = tmp_path / "config.yaml" + _write_config_yaml(config_file, log_level="info") + monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file)) + + app = _build_app() + client = TestClient(app) + assert client.get("/probe").json() == {"log_level": "info"} + + # Edit the file and bump its mtime — simulating a maintainer changing + # max_tokens / model settings in production while the gateway is live. + _write_config_yaml(config_file, log_level="debug") + future_mtime = config_file.stat().st_mtime + 5 + os.utime(config_file, (future_mtime, future_mtime)) + + assert client.get("/probe").json() == {"log_level": "debug"} + + +def test_get_config_respects_runtime_context_override(tmp_path, monkeypatch): + """Per-request ``push_current_app_config`` injection must still win.""" + config_file = tmp_path / "config.yaml" + _write_config_yaml(config_file, log_level="info") + monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file)) + + override = AppConfig(sandbox=SandboxConfig(use="test"), log_level="trace") + push_current_app_config(override) + try: + app = _build_app() + client = TestClient(app) + assert client.get("/probe").json() == {"log_level": "trace"} + finally: + pop_current_app_config() + + +def test_get_config_respects_test_set_app_config(): + """``set_app_config`` (used by upload/skills router tests) keeps working.""" + injected = AppConfig(sandbox=SandboxConfig(use="test"), log_level="warning") + set_app_config(injected) + + app = _build_app() + client = TestClient(app) + assert client.get("/probe").json() == {"log_level": "warning"} + + +def test_run_context_app_config_reflects_yaml_edit(tmp_path, monkeypatch): + """``RunContext.app_config`` must follow live `config.yaml` edits. + + BUG-001 review feedback: the run-context that feeds worker / lead-agent + factories must observe the same mtime reload that `get_config()` does; + otherwise stale config slips back in through the run path even after the + request dependency is fixed. + """ + from unittest.mock import MagicMock + + from app.gateway.deps import get_run_context + + config_file = tmp_path / "config.yaml" + _write_config_yaml(config_file, log_level="info") + monkeypatch.setenv("DEER_FLOW_CONFIG_PATH", str(config_file)) + + app = FastAPI() + # Sentinel values for the rest of the RunContext wiring — we only care + # about ``ctx.app_config`` for this assertion. + app.state.checkpointer = MagicMock() + app.state.store = MagicMock() + app.state.run_event_store = MagicMock() + app.state.run_events_config = {"frozen": "startup"} + app.state.thread_store = MagicMock() + + @app.get("/run-ctx-log-level") + def probe(ctx=Depends(get_run_context)): + return { + "log_level": ctx.app_config.log_level, + "run_events_config": ctx.run_events_config, + } + + client = TestClient(app) + first = client.get("/run-ctx-log-level").json() + assert first == {"log_level": "info", "run_events_config": {"frozen": "startup"}} + + _write_config_yaml(config_file, log_level="debug") + future_mtime = config_file.stat().st_mtime + 5 + os.utime(config_file, (future_mtime, future_mtime)) + + second = client.get("/run-ctx-log-level").json() + # app_config follows the edit; run_events_config stays frozen to the + # startup snapshot we wrote onto app.state above. + assert second == {"log_level": "debug", "run_events_config": {"frozen": "startup"}} + + +@pytest.mark.parametrize( + "exception", + [ + FileNotFoundError("config.yaml not found"), + PermissionError("config.yaml not readable"), + ValueError("invalid config"), + RuntimeError("yaml parse error"), + ], +) +def test_get_config_returns_503_on_any_load_failure(monkeypatch, exception): + """Any failure to materialise the config must surface as 503, not 500. + + Bytedance/deer-flow issue #3107 BUG-001 review: the original snapshot + contract returned 503 when ``app.state.config is None``. The first cut of + this fix only mapped ``FileNotFoundError`` to 503, which left + ``PermissionError`` / ``yaml.YAMLError`` / ``ValidationError`` etc. bubbling + up as 500. Catch every load failure at the request boundary. + """ + + def _broken_get_app_config(): + raise exception + + monkeypatch.setattr(gateway_deps, "get_app_config", _broken_get_app_config) + + app = _build_app() + client = TestClient(app, raise_server_exceptions=False) + response = client.get("/probe") + + assert response.status_code == 503 + assert response.json() == {"detail": "Configuration not available"} diff --git a/backend/tests/test_gateway_deps_config.py b/backend/tests/test_gateway_deps_config.py deleted file mode 100644 index 70f9124b6..000000000 --- a/backend/tests/test_gateway_deps_config.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from fastapi import Depends, FastAPI -from fastapi.testclient import TestClient - -from app.gateway.deps import get_config -from deerflow.config.app_config import AppConfig -from deerflow.config.sandbox_config import SandboxConfig - - -def test_get_config_returns_app_state_config(): - """get_config should return the exact AppConfig stored on app.state.""" - app = FastAPI() - config = AppConfig(sandbox=SandboxConfig(use="test")) - app.state.config = config - - @app.get("/probe") - def probe(cfg: AppConfig = Depends(get_config)): - return {"same_identity": cfg is config, "log_level": cfg.log_level} - - client = TestClient(app) - response = client.get("/probe") - - assert response.status_code == 200 - assert response.json() == {"same_identity": True, "log_level": "info"} - - -def test_get_config_reads_updated_app_state(): - """Swapping app.state.config should be visible to the dependency.""" - app = FastAPI() - app.state.config = AppConfig(sandbox=SandboxConfig(use="test"), log_level="info") - - @app.get("/log-level") - def log_level(cfg: AppConfig = Depends(get_config)): - return {"level": cfg.log_level} - - client = TestClient(app) - assert client.get("/log-level").json() == {"level": "info"} - - app.state.config = app.state.config.model_copy(update={"log_level": "debug"}) - assert client.get("/log-level").json() == {"level": "debug"} diff --git a/backend/tests/test_gateway_lifespan_shutdown.py b/backend/tests/test_gateway_lifespan_shutdown.py index 9319c6268..a694ab00a 100644 --- a/backend/tests/test_gateway_lifespan_shutdown.py +++ b/backend/tests/test_gateway_lifespan_shutdown.py @@ -17,7 +17,7 @@ from fastapi import FastAPI @asynccontextmanager -async def _noop_langgraph_runtime(_app): +async def _noop_langgraph_runtime(_app, _startup_config): yield diff --git a/backend/tests/test_skills_custom_router.py b/backend/tests/test_skills_custom_router.py index ed93e5510..e8a86d8ab 100644 --- a/backend/tests/test_skills_custom_router.py +++ b/backend/tests/test_skills_custom_router.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from fastapi import FastAPI from fastapi.testclient import TestClient +from app.gateway.deps import get_config from app.gateway.routers import skills as skills_router from deerflow.skills.storage import get_or_new_skill_storage from deerflow.skills.types import Skill @@ -38,7 +39,8 @@ def _make_skill(name: str, *, enabled: bool) -> Skill: def _make_test_app(config) -> FastAPI: app = FastAPI() - app.state.config = config + app.state.config = config # kept for any startup-style reads + app.dependency_overrides[get_config] = lambda: config app.include_router(skills_router.router) return app diff --git a/backend/tests/test_task_tool_usage_recorder.py b/backend/tests/test_task_tool_usage_recorder.py new file mode 100644 index 000000000..d7b4ea3b5 --- /dev/null +++ b/backend/tests/test_task_tool_usage_recorder.py @@ -0,0 +1,91 @@ +"""Regression tests for _find_usage_recorder callback shape handling. + +Bytedance issue #3107 BUG-002: When LangChain passes ``config["callbacks"]`` as +an ``AsyncCallbackManager`` (instead of a plain list), the previous +``for cb in callbacks`` loop raised ``TypeError: 'AsyncCallbackManager' object +is not iterable``. ToolErrorHandlingMiddleware then converted the entire ``task`` +tool call into an error ToolMessage, losing the subagent result. +""" + +from types import SimpleNamespace + +from langchain_core.callbacks import AsyncCallbackManager, CallbackManager + +from deerflow.tools.builtins.task_tool import _find_usage_recorder + + +class _RecorderHandler: + def record_external_llm_usage_records(self, records): + self.records = records + + +class _OtherHandler: + pass + + +def _make_runtime(callbacks): + return SimpleNamespace(config={"callbacks": callbacks}) + + +def test_find_usage_recorder_with_plain_list(): + recorder = _RecorderHandler() + runtime = _make_runtime([_OtherHandler(), recorder]) + assert _find_usage_recorder(runtime) is recorder + + +def test_find_usage_recorder_with_async_callback_manager(): + """LangChain wraps callbacks in AsyncCallbackManager for async tool runs. + + The old implementation raised TypeError here. The recorder lives on + ``manager.handlers``; we must look there too. + """ + recorder = _RecorderHandler() + manager = AsyncCallbackManager(handlers=[_OtherHandler(), recorder]) + runtime = _make_runtime(manager) + assert _find_usage_recorder(runtime) is recorder + + +def test_find_usage_recorder_with_sync_callback_manager(): + """Sync flavor of the same wrapper used by some langchain code paths.""" + recorder = _RecorderHandler() + manager = CallbackManager(handlers=[recorder]) + runtime = _make_runtime(manager) + assert _find_usage_recorder(runtime) is recorder + + +def test_find_usage_recorder_returns_none_when_no_recorder(): + manager = AsyncCallbackManager(handlers=[_OtherHandler()]) + runtime = _make_runtime(manager) + assert _find_usage_recorder(runtime) is None + + +def test_find_usage_recorder_handles_empty_manager(): + manager = AsyncCallbackManager(handlers=[]) + runtime = _make_runtime(manager) + assert _find_usage_recorder(runtime) is None + + +def test_find_usage_recorder_returns_none_for_none_runtime(): + assert _find_usage_recorder(None) is None + + +def test_find_usage_recorder_returns_none_when_callbacks_is_none(): + runtime = _make_runtime(None) + assert _find_usage_recorder(runtime) is None + + +def test_find_usage_recorder_returns_none_for_single_handler_object(): + """A single handler instance (not wrapped in a list or manager) should not crash. + + LangChain's contract is that ``config["callbacks"]`` is a list-or-manager, + but we treat any other shape defensively rather than letting a ``for`` loop + blow up at runtime. + """ + runtime = _make_runtime(_RecorderHandler()) + assert _find_usage_recorder(runtime) is None + + +def test_find_usage_recorder_returns_none_when_config_not_dict(): + """Defensive: a runtime without a dict-shaped config should not raise.""" + runtime = SimpleNamespace(config="not-a-dict") + assert _find_usage_recorder(runtime) is None diff --git a/backend/tests/test_uploads_router.py b/backend/tests/test_uploads_router.py index 7846865b8..1bcdb2eb7 100644 --- a/backend/tests/test_uploads_router.py +++ b/backend/tests/test_uploads_router.py @@ -11,6 +11,7 @@ from _router_auth_helpers import call_unwrapped, make_authed_test_app from fastapi import HTTPException, UploadFile from fastapi.testclient import TestClient +from app.gateway.deps import get_config from app.gateway.routers import uploads @@ -631,6 +632,7 @@ def test_upload_limits_endpoint_requires_thread_access(): cfg.uploads = {} app = make_authed_test_app(owner_check_passes=False) app.state.config = cfg + app.dependency_overrides[get_config] = lambda: cfg app.include_router(uploads.router) with TestClient(app) as client: diff --git a/frontend/src/components/workspace/messages/message-list.tsx b/frontend/src/components/workspace/messages/message-list.tsx index ffbf3e3ad..74dca2af5 100644 --- a/frontend/src/components/workspace/messages/message-list.tsx +++ b/frontend/src/components/workspace/messages/message-list.tsx @@ -27,6 +27,7 @@ import { import { useRehypeSplitWordsIntoSpans } from "@/core/rehype"; import type { Subtask } from "@/core/tasks"; import { useUpdateSubtask } from "@/core/tasks/context"; +import { parseSubtaskResult } from "@/core/tasks/subtask-result"; import type { AgentThreadState } from "@/core/threads"; import { cn } from "@/lib/utils"; @@ -359,33 +360,10 @@ export function MessageList({ } else if (message.type === "tool") { const taskId = message.tool_call_id; if (taskId) { - const result = extractTextFromMessage(message); - if (result.startsWith("Task Succeeded. Result:")) { - updateSubtask({ - id: taskId, - status: "completed", - result: result - .split("Task Succeeded. Result:")[1] - ?.trim(), - }); - } else if (result.startsWith("Task failed.")) { - updateSubtask({ - id: taskId, - status: "failed", - error: result.split("Task failed.")[1]?.trim(), - }); - } else if (result.startsWith("Task timed out")) { - updateSubtask({ - id: taskId, - status: "failed", - error: result, - }); - } else { - updateSubtask({ - id: taskId, - status: "in_progress", - }); - } + const parsed = parseSubtaskResult( + extractTextFromMessage(message), + ); + updateSubtask({ id: taskId, ...parsed }); } } } diff --git a/frontend/src/core/messages/utils.ts b/frontend/src/core/messages/utils.ts index 22f985009..1c165fd8d 100644 --- a/frontend/src/core/messages/utils.ts +++ b/frontend/src/core/messages/utils.ts @@ -397,6 +397,50 @@ export function stripUploadedFilesTag(content: string): string { .trim(); } +/** + * Tag names that backend middlewares wrap around internal payloads before + * letting them ride along inside LangGraph message ``content``. + * + * These markers are *not* user copy — they come from: + * + * - ``UploadsMiddleware`` → ```` + * - ``DynamicContextMiddleware`` → ```` (carrying + * ```` / ```` inside) + * - ``TodoListMiddleware`` / ``LoopDetectionMiddleware`` style reminders + * live in ``hide_from_ui`` HumanMessages, but their inner payload uses + * the same tag vocabulary. + * + * The primary export filter is {@link isHiddenFromUIMessage}. This list is + * the defence-in-depth strip for any message that — by middleware bug, + * provider quirk, or merge-conflict regression — slips through without + * its ``hide_from_ui`` flag set. + */ +export const INTERNAL_MARKER_TAGS = [ + "uploaded_files", + "system-reminder", + "memory", + "current_date", +] as const; + +const INTERNAL_MARKER_RE = new RegExp( + `<(${INTERNAL_MARKER_TAGS.join("|")})>[\\s\\S]*?`, + "g", +); + +/** + * Strip every known backend-injected marker from message content. + * + * Intended for the chat export path where a marker leaking through is a + * privacy regression. UI render paths should keep using + * {@link stripUploadedFilesTag} — they receive ``hide_from_ui`` messages + * via a separate filter and the narrower function avoids stripping content + * a user might legitimately type into a meta-discussion (e.g. asking the + * model about its own ```` system). + */ +export function stripInternalMarkers(content: string): string { + return content.replace(INTERNAL_MARKER_RE, "").trim(); +} + export function parseUploadedFiles(content: string): FileInMessage[] { // Match ... tag const uploadedFilesRegex = /([\s\S]*?)<\/uploaded_files>/; diff --git a/frontend/src/core/tasks/subtask-result.ts b/frontend/src/core/tasks/subtask-result.ts new file mode 100644 index 000000000..ac4a422a9 --- /dev/null +++ b/frontend/src/core/tasks/subtask-result.ts @@ -0,0 +1,88 @@ +import type { Subtask } from "./types"; + +export type SubtaskStatus = Subtask["status"]; + +export interface SubtaskResultUpdate { + status: SubtaskStatus; + result?: string; + error?: string; +} + +/** + * Prefix strings the backend `task` tool writes into its result `content`. + * + * These values are not user-facing copy — they are part of the + * backend↔frontend contract defined in + * `backend/packages/harness/deerflow/tools/builtins/task_tool.py` (returned + * from the tool body) and in + * `backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` + * (wrapper for tool exceptions). Any change here must be paired with the + * matching backend change. Exported so a future structured-status migration + * can reference the same values from one place. + * + * `task_tool.py` also emits three `Error:` strings for pre-execution failures + * — unknown subagent type, host-bash disabled, and "task disappeared from + * background tasks". They are handled by {@link ERROR_WRAPPER_PATTERN} + * rather than dedicated prefixes because the wrapper already produces + * exactly the right `terminal failed` shape. + */ +export const SUCCESS_PREFIX = "Task Succeeded. Result:"; +export const FAILURE_PREFIX = "Task failed."; +export const TIMEOUT_PREFIX = "Task timed out"; +export const CANCELLED_PREFIX = "Task cancelled by user."; +export const POLLING_TIMEOUT_PREFIX = "Task polling timed out"; +export const ERROR_WRAPPER_PATTERN = /^Error\b/i; + +/** + * Map a `task` tool result string to a {@link SubtaskStatus}. + * + * Bytedance/deer-flow issue #3107 BUG-007: parent-visible task tool errors do + * not always start with one of the three legacy prefixes (e.g. when + * `ToolErrorHandlingMiddleware` wraps an exception as + * `Error: Tool 'task' failed ...`). Treat any leading `Error:` token as a + * terminal failure so subtask cards stop being stuck on "in_progress". + * + * Returning `in_progress` is the **deliberate** fallback for content that + * matches none of the known prefixes. LangChain only ever emits a + * `ToolMessage` once the tool itself has returned (success or wrapped + * exception), so an unknown shape means "the contract changed underneath us" + * — surfacing it as still-running prompts the operator to investigate, where + * eagerly marking it terminal-failed would mask the drift. + */ +export function parseSubtaskResult(text: string): SubtaskResultUpdate { + const trimmed = text.trim(); + + if (trimmed.startsWith(SUCCESS_PREFIX)) { + return { + status: "completed", + result: trimmed.slice(SUCCESS_PREFIX.length).trim(), + }; + } + + if (trimmed.startsWith(FAILURE_PREFIX)) { + return { + status: "failed", + error: trimmed.slice(FAILURE_PREFIX.length).trim(), + }; + } + + if (trimmed.startsWith(TIMEOUT_PREFIX)) { + return { status: "failed", error: trimmed }; + } + + if (trimmed.startsWith(CANCELLED_PREFIX)) { + return { status: "failed", error: trimmed }; + } + + if (trimmed.startsWith(POLLING_TIMEOUT_PREFIX)) { + return { status: "failed", error: trimmed }; + } + + // ToolErrorHandlingMiddleware-style wrapper, or any other terminal error + // signal the backend forwards to the lead agent. + if (ERROR_WRAPPER_PATTERN.test(trimmed)) { + return { status: "failed", error: trimmed }; + } + + return { status: "in_progress" }; +} diff --git a/frontend/src/core/threads/export.ts b/frontend/src/core/threads/export.ts index cf1f92e47..02eeeb910 100644 --- a/frontend/src/core/threads/export.ts +++ b/frontend/src/core/threads/export.ts @@ -5,16 +5,53 @@ import { extractReasoningContentFromMessage, hasContent, hasToolCalls, - stripUploadedFilesTag, + isHiddenFromUIMessage, + stripInternalMarkers, } from "../messages/utils"; import type { AgentThread } from "./types"; import { titleOfThread } from "./utils"; +/** + * Optional debug switches for advanced exports. + * + * Bytedance/deer-flow issue #3107 BUG-006 explicitly prescribes that the + * default export includes only the user-visible transcript and excludes + * thinking/reasoning content, tool calls, tool results, hidden messages, + * memory injection, and `` payloads. These options let a + * future "debug export" surface re-include any of those categories without + * forking the formatter. They are not currently wired to any UI control — + * callers that want them must construct the options object explicitly. + */ +export interface ExportOptions { + includeReasoning?: boolean; + includeToolCalls?: boolean; + includeToolMessages?: boolean; + includeHidden?: boolean; +} + +function visibleMessages( + messages: Message[], + options: ExportOptions, +): Message[] { + return messages.filter((message) => { + if (!options.includeHidden && isHiddenFromUIMessage(message)) { + return false; + } + if (!options.includeToolMessages && message.type === "tool") { + return false; + } + return true; + }); +} + function formatMessageContent(message: Message): string { const text = extractContentFromMessage(message); if (!text) return ""; - return stripUploadedFilesTag(text); + // Defence-in-depth: even if a middleware-injected marker slipped through + // the `hide_from_ui` filter, scrub every known internal tag before the + // content lands in a user-visible export file. + return stripInternalMarkers(text); } function formatToolCalls(message: Message): string { @@ -26,6 +63,7 @@ function formatToolCalls(message: Message): string { export function formatThreadAsMarkdown( thread: AgentThread, messages: Message[], + options: ExportOptions = {}, ): string { const title = titleOfThread(thread); const createdAt = thread.created_at @@ -41,16 +79,20 @@ export function formatThreadAsMarkdown( "", ]; - for (const message of messages) { + for (const message of visibleMessages(messages, options)) { if (message.type === "human") { const content = formatMessageContent(message); if (content) { lines.push(`## 🧑 User`, "", content, "", "---", ""); } } else if (message.type === "ai") { - const reasoning = extractReasoningContentFromMessage(message); + const reasoning = options.includeReasoning + ? extractReasoningContentFromMessage(message) + : undefined; const content = formatMessageContent(message); - const toolCalls = formatToolCalls(message); + const toolCalls = options.includeToolCalls + ? formatToolCalls(message) + : ""; if (!content && !toolCalls && !reasoning) continue; @@ -83,23 +125,65 @@ export function formatThreadAsMarkdown( return lines.join("\n").trimEnd() + "\n"; } +interface JSONExportMessage { + type: Message["type"]; + id: string | undefined; + content: string; + reasoning?: string; + tool_calls?: unknown; +} + +function buildJSONMessage( + msg: Message, + options: ExportOptions, +): JSONExportMessage | null { + // Run the same sanitiser the Markdown path uses so the JSON `content` + // field never carries inline `...` wrappers, content-array + // thinking blocks, `` markers, or other internal payloads. + const content = formatMessageContent(msg); + const reasoning = + options.includeReasoning && msg.type === "ai" + ? (extractReasoningContentFromMessage(msg) ?? undefined) + : undefined; + const toolCalls = + options.includeToolCalls && + msg.type === "ai" && + "tool_calls" in msg && + msg.tool_calls?.length + ? msg.tool_calls + : undefined; + + // Drop rows with no exportable payload (empty content + no opted-in + // reasoning / tool_calls). Uses falsy semantics so `reasoning: ""` (the + // empty string ``extractReasoningContentFromMessage`` can hand back) is + // treated the same way Markdown's `!reasoning` continue does — otherwise + // an opted-in but empty reasoning field would leak as `{reasoning: ""}`. + if (!content && !reasoning && !toolCalls) { + return null; + } + + return { + type: msg.type, + id: msg.id, + content, + ...(reasoning !== undefined ? { reasoning } : {}), + ...(toolCalls !== undefined ? { tool_calls: toolCalls } : {}), + }; +} + export function formatThreadAsJSON( thread: AgentThread, messages: Message[], + options: ExportOptions = {}, ): string { const exportData = { title: titleOfThread(thread), thread_id: thread.thread_id, created_at: thread.created_at, exported_at: new Date().toISOString(), - messages: messages.map((msg) => ({ - type: msg.type, - id: msg.id, - content: typeof msg.content === "string" ? msg.content : msg.content, - ...(msg.type === "ai" && msg.tool_calls?.length - ? { tool_calls: msg.tool_calls } - : {}), - })), + messages: visibleMessages(messages, options) + .map((msg) => buildJSONMessage(msg, options)) + .filter((m): m is JSONExportMessage => m !== null), }; return JSON.stringify(exportData, null, 2); } diff --git a/frontend/tests/unit/core/tasks/subtask-result.test.ts b/frontend/tests/unit/core/tasks/subtask-result.test.ts new file mode 100644 index 000000000..4f0597fda --- /dev/null +++ b/frontend/tests/unit/core/tasks/subtask-result.test.ts @@ -0,0 +1,112 @@ +import { describe, expect, it } from "vitest"; + +import { parseSubtaskResult } from "@/core/tasks/subtask-result"; + +describe("parseSubtaskResult", () => { + it("recognises the standard success prefix", () => { + const parsed = parseSubtaskResult( + "Task Succeeded. Result: investigated and produced a 3-page report", + ); + expect(parsed.status).toBe("completed"); + expect(parsed.result).toBe("investigated and produced a 3-page report"); + }); + + it("recognises the standard failure prefix", () => { + const parsed = parseSubtaskResult( + "Task failed. underlying tool raised RuntimeError", + ); + expect(parsed.status).toBe("failed"); + expect(parsed.error).toBe("underlying tool raised RuntimeError"); + }); + + it("recognises the standard timeout prefix", () => { + const parsed = parseSubtaskResult("Task timed out after 900s"); + expect(parsed.status).toBe("failed"); + expect(parsed.error).toBe("Task timed out after 900s"); + }); + + it("recognises the cancelled-by-user prefix", () => { + // bytedance/deer-flow#3131 review: this is one of the five terminal + // strings task_tool.py actually emits — the previous cut treated it as + // unrecognised content and pushed the card back to in_progress. + const parsed = parseSubtaskResult("Task cancelled by user."); + expect(parsed.status).toBe("failed"); + expect(parsed.error).toBe("Task cancelled by user."); + }); + + it("recognises the polling-timed-out prefix", () => { + // Emitted by task_tool when the background polling loop runs out of + // budget waiting for the subagent to reach a terminal state. + const parsed = parseSubtaskResult( + "Task polling timed out after 15 minutes. This may indicate the background task is stuck. Status: RUNNING", + ); + expect(parsed.status).toBe("failed"); + expect(parsed.error).toContain("polling timed out"); + }); + + it("recognises polling-timed-out with different durations", () => { + // `task_tool` emits `Task polling timed out after {N} minutes` where N + // varies with the configured subagent timeout. Guard against the regex + // accidentally being pinned to a specific number. + for (const n of [1, 5, 60]) { + const parsed = parseSubtaskResult( + `Task polling timed out after ${n} minutes. Status: RUNNING`, + ); + expect(parsed.status).toBe("failed"); + } + }); + + it("trims whitespace around cancelled and polling-timed-out prefixes", () => { + // Streaming chunks sometimes arrive with leading/trailing newlines. + expect(parseSubtaskResult(" Task cancelled by user. \n").status).toBe( + "failed", + ); + expect( + parseSubtaskResult("\n\nTask polling timed out after 3 minutes").status, + ).toBe("failed"); + }); + + it("recognises task_tool pre-execution Error: returns via the wrapper", () => { + // `task_tool.py` returns three `Error:` strings for unknown subagent + // type, host-bash disabled, and "task disappeared". They share the + // ERROR_WRAPPER_PATTERN, not a dedicated prefix, so this guards + // against a refactor splitting them off. + for (const text of [ + "Error: Unknown subagent type 'foo'. Available: bash, general-purpose", + "Error: Host bash subagent is disabled by configuration", + "Error: Task 1234 disappeared from background tasks", + ]) { + expect(parseSubtaskResult(text).status).toBe("failed"); + } + }); + + it("treats middleware-wrapped tool errors as terminal failures", () => { + // bytedance/deer-flow issue #3107 BUG-007: the parent-visible ToolMessage + // produced by ToolErrorHandlingMiddleware never matches the three legacy + // prefixes, so subtask cards stay stuck on "in_progress". + const parsed = parseSubtaskResult( + "Error: Tool 'task' failed with TypeError: 'AsyncCallbackManager' object is not iterable. Continue with available context, or choose an alternative tool.", + ); + expect(parsed.status).toBe("failed"); + expect(parsed.error).toContain("AsyncCallbackManager"); + }); + + it("treats any other Error: prefix as a terminal failure", () => { + const parsed = parseSubtaskResult("Error: subagent worker pool exhausted"); + expect(parsed.status).toBe("failed"); + }); + + it("keeps unrecognised non-error output as in_progress", () => { + // Streaming partial chunks should not flip the card to terminal early. + const parsed = parseSubtaskResult("Investigating ..."); + expect(parsed.status).toBe("in_progress"); + expect(parsed.error).toBeUndefined(); + expect(parsed.result).toBeUndefined(); + }); + + it("trims surrounding whitespace before matching prefixes", () => { + const parsed = parseSubtaskResult(" Task Succeeded. Result: ok "); + expect(parsed.status).toBe("completed"); + expect(parsed.result).toBe("ok"); + }); +}); diff --git a/frontend/tests/unit/core/threads/export.test.ts b/frontend/tests/unit/core/threads/export.test.ts new file mode 100644 index 000000000..8ee520aa3 --- /dev/null +++ b/frontend/tests/unit/core/threads/export.test.ts @@ -0,0 +1,317 @@ +import type { Message } from "@langchain/langgraph-sdk"; +import { describe, expect, it } from "vitest"; + +import { + formatThreadAsJSON, + formatThreadAsMarkdown, +} from "@/core/threads/export"; +import type { AgentThread } from "@/core/threads/types"; + +// Bytedance/deer-flow issue #3107 BUG-006: the chat export path bypasses the +// UI-level hidden-message filter and emits reasoning content, tool calls, and +// any other "internal" payload as if it were part of the user transcript. + +function makeThread(): AgentThread { + return { + thread_id: "thread-1", + created_at: "2026-05-21T00:00:00Z", + updated_at: "2026-05-21T00:00:00Z", + metadata: { title: "Demo thread" }, + status: "idle", + values: { messages: [] }, + } as unknown as AgentThread; +} + +function human(content: string, extra: Partial = {}): Message { + return { + id: `h-${content}`, + type: "human", + content, + ...extra, + } as Message; +} + +function ai( + content: string, + extra: Partial & { tool_calls?: unknown } = {}, +): Message { + return { + id: `a-${content}`, + type: "ai", + content, + ...extra, + } as Message; +} + +function toolMsg(content: string): Message { + return { + id: `t-${content}`, + type: "tool", + content, + name: "task", + tool_call_id: "call-1", + } as unknown as Message; +} + +describe("formatThreadAsMarkdown", () => { + it("includes plain user and assistant text", () => { + const md = formatThreadAsMarkdown(makeThread(), [ + human("hello"), + ai("hi there"), + ]); + expect(md).toContain("hello"); + expect(md).toContain("hi there"); + }); + + it("drops messages marked hide_from_ui", () => { + const hidden = human("internal system reminder", { + additional_kwargs: { hide_from_ui: true }, + } as Partial); + const md = formatThreadAsMarkdown(makeThread(), [ + hidden, + ai("public answer"), + ]); + expect(md).not.toContain("internal system reminder"); + expect(md).toContain("public answer"); + }); + + it("does not emit reasoning_content by default", () => { + const message = ai("final answer", { + additional_kwargs: { + reasoning_content: "secret chain of thought", + }, + } as Partial); + const md = formatThreadAsMarkdown(makeThread(), [message]); + expect(md).not.toContain("secret chain of thought"); + expect(md).not.toContain("Thinking"); + }); + + it("does not emit tool calls by default", () => { + const message = ai("ok", { + tool_calls: [{ id: "1", name: "task", args: { description: "do work" } }], + } as Partial); + const md = formatThreadAsMarkdown(makeThread(), [message]); + expect(md).not.toContain("**Tool:**"); + expect(md).not.toContain("`task`"); + }); + + it("drops tool result messages", () => { + const md = formatThreadAsMarkdown(makeThread(), [ + ai("delegating"), + toolMsg("Task Succeeded. Result: confidential"), + ]); + expect(md).not.toContain("confidential"); + }); +}); + +describe("formatThreadAsMarkdown opt-in flags", () => { + it("emits reasoning when includeReasoning is true", () => { + const message = ai("final answer", { + additional_kwargs: { + reasoning_content: "step-by-step chain of thought", + }, + } as Partial); + const md = formatThreadAsMarkdown(makeThread(), [message], { + includeReasoning: true, + }); + expect(md).toContain("step-by-step chain of thought"); + expect(md).toContain("Thinking"); + }); + + it("emits tool call rows when includeToolCalls is true", () => { + const message = ai("ok", { + tool_calls: [{ id: "1", name: "task", args: { description: "do work" } }], + } as Partial); + const md = formatThreadAsMarkdown(makeThread(), [message], { + includeToolCalls: true, + }); + expect(md).toContain("**Tool:**"); + expect(md).toContain("`task`"); + }); + + it("keeps hidden messages when includeHidden is true", () => { + const hidden = human("internal reminder", { + additional_kwargs: { hide_from_ui: true }, + } as Partial); + const md = formatThreadAsMarkdown(makeThread(), [hidden], { + includeHidden: true, + }); + expect(md).toContain("internal reminder"); + }); +}); + +describe("formatThreadAsJSON opt-in flags", () => { + it("emits tool_calls field when includeToolCalls is true", () => { + const message = ai("ok", { + tool_calls: [{ id: "1", name: "task", args: { description: "x" } }], + } as Partial); + const raw = formatThreadAsJSON(makeThread(), [message], { + includeToolCalls: true, + }); + expect(raw).toContain("tool_calls"); + expect(raw).toContain('"task"'); + }); + + it("keeps tool messages when includeToolMessages is true", () => { + const raw = formatThreadAsJSON( + makeThread(), + [toolMsg("Task Succeeded. Result: keep me")], + { includeToolMessages: true }, + ); + const parsed = JSON.parse(raw) as { messages: { type: string }[] }; + expect(parsed.messages.some((m) => m.type === "tool")).toBe(true); + expect(raw).toContain("keep me"); + }); +}); + +describe("formatThreadAsJSON", () => { + it("strips hidden messages, tool messages, reasoning, and tool calls", () => { + const messages = [ + human("hello"), + human("secret reminder", { + additional_kwargs: { hide_from_ui: true }, + } as Partial), + ai("answer", { + additional_kwargs: { + reasoning_content: "secret reasoning", + }, + tool_calls: [{ id: "1", name: "task", args: {} }], + } as Partial), + toolMsg("internal trace"), + ]; + const raw = formatThreadAsJSON(makeThread(), messages); + const parsed = JSON.parse(raw) as { + messages: { type: string; tool_calls?: unknown[] }[]; + }; + + expect(parsed.messages).toHaveLength(2); + expect(parsed.messages.every((m) => m.type !== "tool")).toBe(true); + expect(raw).not.toContain("secret reminder"); + expect(raw).not.toContain("secret reasoning"); + expect(raw).not.toContain("internal trace"); + expect(raw).not.toContain("tool_calls"); + }); + + it("strips inline ... wrappers from content", () => { + // bytedance/deer-flow#3131 review: JSON export must run the same + // sanitiser the Markdown path uses so inline reasoning never leaks + // even when `includeReasoning` is left at its default false. + const message = ai("internal monologuevisible answer", { + id: "ai-1", + } as Partial); + const raw = formatThreadAsJSON(makeThread(), [message]); + expect(raw).not.toContain("internal monologue"); + expect(raw).not.toContain(""); + expect(raw).toContain("visible answer"); + }); + + it("strips content-array thinking blocks from content", () => { + const message = ai("placeholder", { + id: "ai-2", + content: [ + { type: "thinking", thinking: "hidden reasoning step" }, + { type: "text", text: "final visible text" }, + ], + } as unknown as Partial); + const raw = formatThreadAsJSON(makeThread(), [message]); + expect(raw).not.toContain("hidden reasoning step"); + expect(raw).toContain("final visible text"); + }); + + it("strips markers from content", () => { + const message = human( + "real prompt\n\n/mnt/user-data/uploads/secret.pdf\n", + { id: "h-clean" } as Partial, + ); + const raw = formatThreadAsJSON(makeThread(), [message]); + expect(raw).not.toContain(""); + expect(raw).not.toContain("secret.pdf"); + expect(raw).toContain("real prompt"); + }); + + it("drops AI messages that sanitise to empty content", () => { + // Pure-reasoning AI fragments (no visible text, no tool calls) should + // not survive as `{content: ""}` rows in the export. + const message = ai("only thinking, no answer", { + id: "ai-3", + } as Partial); + const raw = formatThreadAsJSON(makeThread(), [message]); + const parsed = JSON.parse(raw) as { messages: unknown[] }; + expect(parsed.messages).toHaveLength(0); + }); + + it("strips // as defence in depth", () => { + // Primary protection is `isHiddenFromUIMessage` filtering the whole + // hidden HumanMessage. If a regression strips the `hide_from_ui` flag + // (or the marker leaks into an otherwise-visible message), the + // sanitiser must still scrub the payload before export. + const leaky = human("real user text", { + id: "leak-1", + content: + "\nsecret fact A\n2026-01-01, Tuesday\n\nreal user text", + // Deliberately *not* setting hide_from_ui to model the regression + // case the defence-in-depth strip is guarding against. + } as unknown as Partial); + const raw = formatThreadAsJSON(makeThread(), [leaky]); + expect(raw).not.toContain(""); + expect(raw).not.toContain(""); + expect(raw).not.toContain(""); + expect(raw).not.toContain("secret fact A"); + expect(raw).toContain("real user text"); + }); + + it("sanitises tool message content when includeToolMessages is true", () => { + const message = { + id: "t-leak", + type: "tool", + content: + "Task Succeeded. Result: payload\n\n/mnt/user-data/uploads/secret.pdf\n", + name: "task", + tool_call_id: "call-leak", + } as unknown as Message; + + const raw = formatThreadAsJSON(makeThread(), [message], { + includeToolMessages: true, + }); + expect(raw).toContain("Task Succeeded"); + expect(raw).not.toContain(""); + expect(raw).not.toContain("secret.pdf"); + }); + + it("preserves text and image_url parts in mixed content arrays", () => { + // `extractContentFromMessage` keeps `text` and `image_url` parts and + // drops `thinking` parts. The JSON export must agree with that + // contract. + const message = ai("placeholder", { + id: "ai-mixed", + content: [ + { type: "thinking", thinking: "internal reasoning" }, + { type: "text", text: "user-visible answer" }, + { + type: "image_url", + image_url: { url: "https://example.invalid/cat.png" }, + }, + ], + } as unknown as Partial); + const raw = formatThreadAsJSON(makeThread(), [message]); + expect(raw).toContain("user-visible answer"); + expect(raw).toContain("https://example.invalid/cat.png"); + expect(raw).not.toContain("internal reasoning"); + }); + + it("drops opted-in empty reasoning rather than emit reasoning: ''", () => { + // `extractReasoningContentFromMessage` can legitimately hand back "" + // for an AI message that has no reasoning content. The export must + // mirror the Markdown path's `!reasoning` `continue` and drop the row + // instead of leaking `{reasoning: ""}`. + const message = ai("", { + id: "ai-empty-reasoning", + additional_kwargs: { reasoning_content: "" }, + } as Partial); + const raw = formatThreadAsJSON(makeThread(), [message], { + includeReasoning: true, + }); + const parsed = JSON.parse(raw) as { messages: unknown[] }; + expect(parsed.messages).toHaveLength(0); + }); +}); From c881d95898414b14ad6116d8fdf9c011c6f3eaf4 Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Thu, 21 May 2026 23:22:20 +0800 Subject: [PATCH 65/86] fix(mcp): persist MCP sessions across tool calls for stateful servers (#3089) * fix(mcp): persist MCP sessions across tool calls for stateful servers MCP tools loaded via langchain-mcp-adapters created a new session on every call, causing stateful servers like Playwright to lose browser state (pages, forms) between consecutive tool invocations within the same thread. Add MCPSessionPool that maintains persistent sessions scoped by (server_name, thread_id). Tool calls within the same thread now reuse the same MCP session, preserving server-side state. Sessions are evicted in LRU order (max 256) and cleaned up on cache invalidation. Fixes #3054 * fix(sandbox): add group/other read permissions to uploaded files for Docker sandbox (#3127) When using AIO sandbox with LocalContainerBackend, uploaded files are created with 0o600 (owner-only) permissions by the gateway process running as root. The sandbox process inside the Docker container runs as a non-root user and cannot read these bind-mounted files, causing a "Permission denied" error on read_file. Add `needs_upload_permission_adjustment` attribute to SandboxProvider (default True) to indicate that uploaded files need chmod adjustment. LocalSandboxProvider opts out (same user). A new `_make_file_sandbox_readable` function adds S_IRGRP | S_IROTH bits after files are written, changing permissions from 0o600 to 0o644 so the sandbox can read the uploads. * fix(mcp): address review comments on session pool and tools - _extract_thread_id: return "default" instead of stringifying None when get_config() returns no thread_id - call_with_persistent_session: fix **arguments annotation from dict[str,Any] to Any - Replace private _convert_call_tool_result import with a local implementation that handles all MCP content block types - _make_session_pool_tool: accept tool_interceptors and apply the configured interceptor chain on every call (preserving OAuth and custom interceptors) - MCPSessionPool: replace asyncio.Lock with threading.Lock; restructure get/close methods to never await while holding the lock; add close_all_sync() that closes sessions on their owning event loops - reset_mcp_tools_cache: use pool.close_all_sync() instead of asyncio.run-in-thread to close sessions deterministically - test: add test_session_pool_tool_sync_wrapper_path_is_safe covering tool invocation via the sync wrapper (tool.func) path Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/9e7f9e7f-1d2b-464a-b3b7-7f1649b74122 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * fix(mcp): extract SESSION_CLOSE_TIMEOUT to class constant Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/9e7f9e7f-1d2b-464a-b3b7-7f1649b74122 Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> * Potential fix for pull request finding 'Empty except' Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com> --- .../packages/harness/deerflow/mcp/cache.py | 16 + .../harness/deerflow/mcp/session_pool.py | 198 +++++++++ .../packages/harness/deerflow/mcp/tools.py | 198 ++++++++- backend/tests/test_mcp_session_pool.py | 409 ++++++++++++++++++ 4 files changed, 813 insertions(+), 8 deletions(-) create mode 100644 backend/packages/harness/deerflow/mcp/session_pool.py create mode 100644 backend/tests/test_mcp_session_pool.py diff --git a/backend/packages/harness/deerflow/mcp/cache.py b/backend/packages/harness/deerflow/mcp/cache.py index c1121f59d..f04fe0054 100644 --- a/backend/packages/harness/deerflow/mcp/cache.py +++ b/backend/packages/harness/deerflow/mcp/cache.py @@ -134,9 +134,25 @@ def reset_mcp_tools_cache() -> None: """Reset the MCP tools cache. This is useful for testing or when you want to reload MCP tools. + Also closes all persistent MCP sessions so they are recreated on + the next tool load. """ global _mcp_tools_cache, _cache_initialized, _config_mtime _mcp_tools_cache = None _cache_initialized = False _config_mtime = None + + # Close persistent sessions – they will be recreated by the next + # get_mcp_tools() call with the (possibly updated) connection config. + try: + from deerflow.mcp.session_pool import get_session_pool + + pool = get_session_pool() + pool.close_all_sync() + except Exception: + logger.debug("Could not close MCP session pool on cache reset", exc_info=True) + + from deerflow.mcp.session_pool import reset_session_pool + + reset_session_pool() logger.info("MCP tools cache reset") diff --git a/backend/packages/harness/deerflow/mcp/session_pool.py b/backend/packages/harness/deerflow/mcp/session_pool.py new file mode 100644 index 000000000..8450cac8e --- /dev/null +++ b/backend/packages/harness/deerflow/mcp/session_pool.py @@ -0,0 +1,198 @@ +"""Persistent MCP session pool for stateful tool calls. + +When MCP tools are loaded via langchain-mcp-adapters with ``session=None``, +each tool call creates a new MCP session. For stateful servers like Playwright, +this means browser state (opened pages, filled forms) is lost between calls. + +This module provides a session pool that maintains persistent MCP sessions, +scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id — +so that consecutive tool calls share the same session and server-side state. +Sessions are evicted in LRU order when the pool reaches capacity. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from collections import OrderedDict +from typing import Any + +from mcp import ClientSession + +logger = logging.getLogger(__name__) + + +class MCPSessionPool: + """Manages persistent MCP sessions scoped by ``(server_name, scope_key)``.""" + + MAX_SESSIONS = 256 + SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe + + def __init__(self) -> None: + self._entries: OrderedDict[ + tuple[str, str], + tuple[ClientSession, asyncio.AbstractEventLoop], + ] = OrderedDict() + self._context_managers: dict[tuple[str, str], Any] = {} + # threading.Lock is not bound to any event loop, so it is safe to + # acquire from both async paths and sync/worker-thread paths. + self._lock = threading.Lock() + + async def get_session( + self, + server_name: str, + scope_key: str, + connection: dict[str, Any], + ) -> ClientSession: + """Get or create a persistent MCP session. + + If an existing session was created in a different event loop (e.g. + the sync-wrapper path), it is closed and replaced with a fresh one + in the current loop. + + Args: + server_name: MCP server name. + scope_key: Isolation key (typically thread_id). + connection: Connection configuration for ``create_session``. + + Returns: + An initialized ``ClientSession``. + """ + key = (server_name, scope_key) + current_loop = asyncio.get_running_loop() + + # Phase 1: inspect/mutate the registry under the thread lock (no awaits). + cms_to_close: list[tuple[tuple[str, str], Any]] = [] + with self._lock: + if key in self._entries: + session, loop = self._entries[key] + if loop is current_loop: + self._entries.move_to_end(key) + return session + # Session belongs to a different event loop – evict it. + cm = self._context_managers.pop(key, None) + self._entries.pop(key) + if cm is not None: + cms_to_close.append((key, cm)) + + # Evict LRU entries when at capacity. + while len(self._entries) >= self.MAX_SESSIONS: + oldest_key = next(iter(self._entries)) + cm = self._context_managers.pop(oldest_key, None) + self._entries.pop(oldest_key) + if cm is not None: + cms_to_close.append((oldest_key, cm)) + + # Phase 2: async cleanup outside the lock so we never await while holding it. + for close_key, cm in cms_to_close: + try: + await cm.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing MCP session %s", close_key, exc_info=True) + + from langchain_mcp_adapters.sessions import create_session + + cm = create_session(connection) + session = await cm.__aenter__() + await session.initialize() + + # Phase 3: register the new session under the lock. + with self._lock: + self._entries[key] = (session, current_loop) + self._context_managers[key] = cm + logger.info("Created persistent MCP session for %s/%s", server_name, scope_key) + return session + + # ------------------------------------------------------------------ + # Cleanup helpers + # ------------------------------------------------------------------ + + async def _close_cm(self, key: tuple[str, str], cm: Any) -> None: + """Close a single context manager (must be called WITHOUT the lock).""" + try: + await cm.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing MCP session %s", key, exc_info=True) + + async def close_scope(self, scope_key: str) -> None: + """Close all sessions for a given scope (e.g. thread_id).""" + with self._lock: + keys = [k for k in self._entries if k[1] == scope_key] + cms = [(k, self._context_managers.pop(k, None)) for k in keys] + for k in keys: + self._entries.pop(k, None) + for key, cm in cms: + if cm is not None: + await self._close_cm(key, cm) + + async def close_server(self, server_name: str) -> None: + """Close all sessions for a given server.""" + with self._lock: + keys = [k for k in self._entries if k[0] == server_name] + cms = [(k, self._context_managers.pop(k, None)) for k in keys] + for k in keys: + self._entries.pop(k, None) + for key, cm in cms: + if cm is not None: + await self._close_cm(key, cm) + + async def close_all(self) -> None: + """Close every managed session.""" + with self._lock: + cms = list(self._context_managers.items()) + self._context_managers.clear() + self._entries.clear() + for key, cm in cms: + await self._close_cm(key, cm) + + def close_all_sync(self) -> None: + """Close all sessions using their owning event loops (synchronous). + + Each session is closed on the loop it was created in, avoiding + cross-loop resource leaks. Safe to call from any thread without an + active event loop. + """ + with self._lock: + entries = list(self._entries.items()) + cms = dict(self._context_managers) + self._entries.clear() + self._context_managers.clear() + + for key, (_, loop) in entries: + cm = cms.get(key) + if cm is None or loop.is_closed(): + continue + try: + if loop.is_running(): + # Schedule on the owning loop from this (different) thread. + future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop) + future.result(timeout=self.SESSION_CLOSE_TIMEOUT) + else: + loop.run_until_complete(cm.__aexit__(None, None, None)) + except Exception: + logger.debug("Error closing MCP session %s during sync close", key, exc_info=True) + + +# ------------------------------------------------------------------ +# Module-level singleton +# ------------------------------------------------------------------ + +_pool: MCPSessionPool | None = None +_pool_lock = threading.Lock() + + +def get_session_pool() -> MCPSessionPool: + """Return the global session-pool singleton.""" + global _pool + if _pool is None: + with _pool_lock: + if _pool is None: + _pool = MCPSessionPool() + return _pool + + +def reset_session_pool() -> None: + """Reset the singleton (for tests).""" + global _pool + _pool = None diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index d27641692..d08e7efd6 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -1,21 +1,181 @@ -"""Load MCP tools using langchain-mcp-adapters.""" +"""Load MCP tools using langchain-mcp-adapters with persistent sessions.""" + +from __future__ import annotations import logging +from typing import Any -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, StructuredTool +from langgraph.config import get_config from deerflow.config.extensions_config import ExtensionsConfig from deerflow.mcp.client import build_servers_config from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers +from deerflow.mcp.session_pool import get_session_pool from deerflow.reflection import resolve_variable from deerflow.tools.sync import make_sync_tool_wrapper +from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) +def _extract_thread_id(runtime: Runtime | None) -> str: + """Extract thread_id from the injected tool runtime or LangGraph config.""" + if runtime is not None: + tid = runtime.context.get("thread_id") if runtime.context else None + if tid is not None: + return str(tid) + config = runtime.config or {} + tid = config.get("configurable", {}).get("thread_id") + if tid is not None: + return str(tid) + + try: + tid = get_config().get("configurable", {}).get("thread_id") + return str(tid) if tid is not None else "default" + except RuntimeError: + return "default" + + +def _convert_call_tool_result(call_tool_result: Any) -> Any: + """Convert an MCP CallToolResult to the LangChain ``content_and_artifact`` format. + + Implements the same conversion logic as the adapter without relying on + the private ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol. + """ + from langchain_core.messages import ToolMessage + from langchain_core.messages.content import create_file_block, create_image_block, create_text_block + from langchain_core.tools import ToolException + from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents + + # Pass ToolMessage through directly (interceptor short-circuit). + if isinstance(call_tool_result, ToolMessage): + return call_tool_result, None + + # Pass LangGraph Command through directly when langgraph is installed. + try: + from langgraph.types import Command + + if isinstance(call_tool_result, Command): + return call_tool_result, None + except ImportError: + # langgraph is optional; if unavailable, continue with standard MCP content conversion. + pass + + # Convert MCP content blocks to LangChain content blocks. + lc_content = [] + for item in call_tool_result.content: + if isinstance(item, TextContent): + lc_content.append(create_text_block(text=item.text)) + elif isinstance(item, ImageContent): + lc_content.append(create_image_block(base64=item.data, mime_type=item.mimeType)) + elif isinstance(item, ResourceLink): + mime = item.mimeType or None + if mime and mime.startswith("image/"): + lc_content.append(create_image_block(url=str(item.uri), mime_type=mime)) + else: + lc_content.append(create_file_block(url=str(item.uri), mime_type=mime)) + elif isinstance(item, EmbeddedResource): + from mcp.types import BlobResourceContents + + res = item.resource + if isinstance(res, TextResourceContents): + lc_content.append(create_text_block(text=res.text)) + elif isinstance(res, BlobResourceContents): + mime = res.mimeType or None + if mime and mime.startswith("image/"): + lc_content.append(create_image_block(base64=res.blob, mime_type=mime)) + else: + lc_content.append(create_file_block(base64=res.blob, mime_type=mime)) + else: + lc_content.append(create_text_block(text=str(res))) + else: + lc_content.append(create_text_block(text=str(item))) + + if call_tool_result.isError: + error_parts = [item["text"] for item in lc_content if isinstance(item, dict) and item.get("type") == "text"] + raise ToolException("\n".join(error_parts) if error_parts else str(lc_content)) + + artifact = None + if call_tool_result.structuredContent is not None: + artifact = {"structured_content": call_tool_result.structuredContent} + + return lc_content, artifact + + +def _make_session_pool_tool( + tool: BaseTool, + server_name: str, + connection: dict[str, Any], + tool_interceptors: list[Any] | None = None, +) -> BaseTool: + """Wrap an MCP tool so it reuses a persistent session from the pool. + + Replaces the per-call session creation with pool-managed sessions scoped + by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g. + Playwright) keep their state across tool calls within the same thread. + + The configured ``tool_interceptors`` (OAuth, custom) are preserved and + applied on every call before invoking the pooled session. + """ + # Strip the server-name prefix to recover the original MCP tool name. + original_name = tool.name + prefix = f"{server_name}_" + if original_name.startswith(prefix): + original_name = original_name[len(prefix) :] + + pool = get_session_pool() + + async def call_with_persistent_session( + runtime: Runtime | None = None, + **arguments: Any, + ) -> Any: + thread_id = _extract_thread_id(runtime) + session = await pool.get_session(server_name, thread_id, connection) + + if tool_interceptors: + from langchain_mcp_adapters.interceptors import MCPToolCallRequest + + async def base_handler(request: MCPToolCallRequest) -> Any: + return await session.call_tool(request.name, request.args) + + handler = base_handler + for interceptor in reversed(tool_interceptors): + outer = handler + + async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any: + return await _i(req, _h) + + handler = wrapped + + request = MCPToolCallRequest( + name=original_name, + args=arguments, + server_name=server_name, + runtime=runtime, + ) + call_tool_result = await handler(request) + else: + call_tool_result = await session.call_tool(original_name, arguments) + + return _convert_call_tool_result(call_tool_result) + + return StructuredTool( + name=tool.name, + description=tool.description, + args_schema=tool.args_schema, + coroutine=call_with_persistent_session, + response_format="content_and_artifact", + metadata=tool.metadata, + ) + + async def get_mcp_tools() -> list[BaseTool]: """Get all tools from enabled MCP servers. + Tools are wrapped with persistent-session logic so that consecutive + calls within the same thread reuse the same MCP session. + Returns: List of LangChain tools from all enabled MCP servers. """ @@ -50,7 +210,7 @@ async def get_mcp_tools() -> list[BaseTool]: existing_headers["Authorization"] = auth_header servers_config[server_name]["headers"] = existing_headers - tool_interceptors = [] + tool_interceptors: list[Any] = [] oauth_interceptor = build_oauth_tool_interceptor(extensions_config) if oauth_interceptor is not None: tool_interceptors.append(oauth_interceptor) @@ -74,20 +234,42 @@ async def get_mcp_tools() -> list[BaseTool]: elif interceptor is not None: logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping") except Exception as e: - logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True) + logger.warning( + f"Failed to load MCP interceptor {interceptor_path}: {e}", + exc_info=True, + ) - client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True) + client = MultiServerMCPClient( + servers_config, + tool_interceptors=tool_interceptors, + tool_name_prefix=True, + ) - # Get all tools from all servers + # Get all tools from all servers (discovers tool definitions via + # temporary sessions – the persistent-session wrapping is applied below). tools = await client.get_tools() logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers") - # Patch tools to support sync invocation, as deerflow client streams synchronously + # Wrap each tool with persistent-session logic. + wrapped_tools: list[BaseTool] = [] for tool in tools: + tool_server: str | None = None + for name in servers_config: + if tool.name.startswith(f"{name}_"): + tool_server = name + break + + if tool_server is not None: + wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors)) + else: + wrapped_tools.append(tool) + + # Patch tools to support sync invocation, as deerflow client streams synchronously + for tool in wrapped_tools: if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None: tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name) - return tools + return wrapped_tools except Exception as e: logger.error(f"Failed to load MCP tools: {e}", exc_info=True) diff --git a/backend/tests/test_mcp_session_pool.py b/backend/tests/test_mcp_session_pool.py new file mode 100644 index 000000000..822ad2e81 --- /dev/null +++ b/backend/tests/test_mcp_session_pool.py @@ -0,0 +1,409 @@ +"""Tests for the MCP persistent-session pool.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool + + +@pytest.fixture(autouse=True) +def _reset_pool(): + reset_session_pool() + yield + reset_session_pool() + + +# --------------------------------------------------------------------------- +# MCPSessionPool unit tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_session_creates_new(): + """First call for a key creates a new session.""" + pool = MCPSessionPool() + + mock_session = AsyncMock() + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + + assert session is mock_session + mock_session.initialize.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_session_reuses_existing(): + """Second call for the same key returns the cached session.""" + pool = MCPSessionPool() + + mock_session = AsyncMock() + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + + assert s1 is s2 + # Only one session should have been created. + assert mock_cm.__aenter__.await_count == 1 + + +@pytest.mark.asyncio +async def test_different_scope_creates_different_session(): + """Different scope keys get different sessions.""" + pool = MCPSessionPool() + + sessions = [AsyncMock(), AsyncMock()] + idx = 0 + + class CmFactory: + def __init__(self): + self.enter_count = 0 + + async def __aenter__(self): + nonlocal idx + s = sessions[idx] + idx += 1 + self.enter_count += 1 + return s + + async def __aexit__(self, *args): + return False + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()): + s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []}) + + assert s1 is not s2 + assert s1 is sessions[0] + assert s2 is sessions[1] + + +@pytest.mark.asyncio +async def test_lru_eviction(): + """Oldest entries are evicted when the pool is full.""" + pool = MCPSessionPool() + pool.MAX_SESSIONS = 2 + + class CmFactory: + def __init__(self): + self.closed = False + + async def __aenter__(self): + return AsyncMock() + + async def __aexit__(self, *args): + self.closed = True + return False + + cms: list[CmFactory] = [] + + def make_cm(*a, **kw): + cm = CmFactory() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []}) + # Pool is full (2). Adding t3 should evict t1. + await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []}) + + assert cms[0].closed is True + assert cms[1].closed is False + assert cms[2].closed is False + + +@pytest.mark.asyncio +async def test_close_scope(): + """close_scope shuts down sessions for a specific scope key.""" + pool = MCPSessionPool() + + class CmFactory: + def __init__(self): + self.closed = False + + async def __aenter__(self): + return AsyncMock() + + async def __aexit__(self, *args): + self.closed = True + return False + + cms: list[CmFactory] = [] + + def make_cm(*a, **kw): + cm = CmFactory() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []}) + + await pool.close_scope("t1") + + assert cms[0].closed is True + assert cms[1].closed is False + + # t2 session still exists. + assert ("s", "t2") in pool._entries + + +@pytest.mark.asyncio +async def test_close_all(): + """close_all shuts down every session.""" + pool = MCPSessionPool() + + class CmFactory: + def __init__(self): + self.closed = False + + async def __aenter__(self): + return AsyncMock() + + async def __aexit__(self, *args): + self.closed = True + return False + + cms: list[CmFactory] = [] + + def make_cm(*a, **kw): + cm = CmFactory() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []}) + await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []}) + + await pool.close_all() + + assert all(cm.closed for cm in cms) + assert len(pool._entries) == 0 + + +# --------------------------------------------------------------------------- +# Singleton helpers +# --------------------------------------------------------------------------- + + +def test_get_session_pool_singleton(): + """get_session_pool returns the same instance.""" + p1 = get_session_pool() + p2 = get_session_pool() + assert p1 is p2 + + +def test_reset_session_pool(): + """reset_session_pool clears the singleton.""" + p1 = get_session_pool() + reset_session_pool() + p2 = get_session_pool() + assert p1 is not p2 + + +# --------------------------------------------------------------------------- +# Integration: _make_session_pool_tool uses the pool +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_session_pool_tool_wrapping(): + """The wrapper tool delegates to a pool-managed session.""" + # Build a dummy StructuredTool (as returned by langchain-mcp-adapters). + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + url: str = Field(..., description="url") + + original_tool = StructuredTool( + name="playwright_navigate", + description="Navigate browser", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + connection = {"transport": "stdio", "command": "pw", "args": []} + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "playwright", connection) + + # Simulate a tool call with a runtime context containing thread_id. + mock_runtime = MagicMock() + mock_runtime.context = {"thread_id": "thread-42"} + mock_runtime.config = {} + + await wrapped.coroutine(runtime=mock_runtime, url="https://example.com") + + mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"}) + + +@pytest.mark.asyncio +async def test_session_pool_tool_extracts_thread_id(): + """Thread ID is extracted from runtime.config when not in context.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="server_tool", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []}) + + mock_runtime = MagicMock() + mock_runtime.context = {} + mock_runtime.config = {"configurable": {"thread_id": "from-config"}} + + await wrapped.coroutine(runtime=mock_runtime, x=1) + + # Verify the session was created with the correct scope key. + pool = get_session_pool() + assert ("server", "from-config") in pool._entries + + +@pytest.mark.asyncio +async def test_session_pool_tool_default_scope(): + """When no thread_id is available, 'default' is used as scope key.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="server_tool", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []}) + + # No thread_id in runtime at all. + await wrapped.coroutine(runtime=None, x=1) + + pool = get_session_pool() + assert ("server", "default") in pool._entries + + +@pytest.mark.asyncio +async def test_session_pool_tool_get_config_fallback(): + """When runtime is None, get_config() provides thread_id as fallback.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="server_tool", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + fake_config = {"configurable": {"thread_id": "from-langgraph-config"}} + + with ( + patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm), + patch("deerflow.mcp.tools.get_config", return_value=fake_config), + ): + wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []}) + + # runtime=None — get_config() fallback should provide thread_id + await wrapped.coroutine(runtime=None, x=1) + + pool = get_session_pool() + assert ("server", "from-langgraph-config") in pool._entries + + +def test_session_pool_tool_sync_wrapper_path_is_safe(): + """Sync wrapper (tool.func) invocation doesn't crash on cross-loop access.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + from deerflow.tools.sync import make_sync_tool_wrapper + + class Args(BaseModel): + url: str = Field(..., description="url") + + original_tool = StructuredTool( + name="playwright_navigate", + description="Navigate browser", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + connection = {"transport": "stdio", "command": "pw", "args": []} + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "playwright", connection) + # Attach the sync wrapper exactly as get_mcp_tools() does. + wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name) + + # Call via the sync path (asyncio.run in a worker thread). + # runtime is not supplied so _extract_thread_id falls back to "default". + wrapped.func(url="https://example.com") + + mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"}) From 253542ea0de6e6a038170a03f13fd2e2f7e32164 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 22 May 2026 03:19:23 +0200 Subject: [PATCH 66/86] docs: discourage MCP filesystem workspace config (#3141) --- backend/docs/API.md | 7 ------- backend/docs/MCP_SERVER.md | 16 ++++++++++++++-- extensions_config.example.json | 14 +------------- frontend/src/content/en/harness/mcp.mdx | 15 ++++++++++----- .../src/content/zh/application/configuration.mdx | 12 ++++++++++-- frontend/src/content/zh/harness/mcp.mdx | 14 +++++++++----- 6 files changed, 44 insertions(+), 34 deletions(-) diff --git a/backend/docs/API.md b/backend/docs/API.md index 762a135c4..10ea99858 100644 --- a/backend/docs/API.md +++ b/backend/docs/API.md @@ -241,13 +241,6 @@ GET /api/mcp/config "GITHUB_TOKEN": "***" }, "description": "GitHub operations" - }, - "filesystem": { - "enabled": false, - "type": "stdio", - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem"], - "description": "File system access" } } } diff --git a/backend/docs/MCP_SERVER.md b/backend/docs/MCP_SERVER.md index b7320f8cc..ba5ccd769 100644 --- a/backend/docs/MCP_SERVER.md +++ b/backend/docs/MCP_SERVER.md @@ -14,6 +14,19 @@ DeerFlow supports configurable MCP servers and skills to extend its capabilities 3. Configure each server’s command, arguments, and environment variables as needed. 4. Restart the application to load and register MCP tools. +## Filesystem MCP Servers + +DeerFlow already provides built-in file tools for thread-scoped workspace access. +Do not add an MCP filesystem server for the same DeerFlow workspace. The +overlapping file tools use different path semantics, which can make LLM tool +selection and file access behavior unstable. + +DeerFlow does not currently adapt the MCP Roots mode for filesystem servers. In +particular, it does not publish per-thread MCP roots or map DeerFlow sandbox +paths such as `/mnt/user-data/...` to paths accepted by +`@modelcontextprotocol/server-filesystem`. Use DeerFlow's built-in file tools +for DeerFlow workspace files. + ## OAuth Support (HTTP/SSE MCP Servers) For `http` and `sse` MCP servers, DeerFlow supports OAuth token acquisition and automatic token refresh. @@ -88,7 +101,6 @@ MCP servers expose tools that are automatically discovered and integrated into D MCP servers can provide access to: -- **File systems** - **Databases** (e.g., PostgreSQL) - **External APIs** (e.g., GitHub, Brave Search) - **Browser automation** (e.g., Puppeteer) @@ -97,4 +109,4 @@ MCP servers can provide access to: ## Learn More For detailed documentation about the Model Context Protocol, visit: -https://modelcontextprotocol.io \ No newline at end of file +https://modelcontextprotocol.io diff --git a/extensions_config.example.json b/extensions_config.example.json index 118c5d6db..7c0dce740 100644 --- a/extensions_config.example.json +++ b/extensions_config.example.json @@ -3,18 +3,6 @@ "my_package.mcp.auth:build_auth_interceptor" ], "mcpServers": { - "filesystem": { - "enabled": false, - "type": "stdio", - "command": "npx", - "args": [ - "-y", - "@modelcontextprotocol/server-filesystem", - "/path/to/allowed/files" - ], - "env": {}, - "description": "Provides filesystem access within allowed directories" - }, "github": { "enabled": false, "type": "stdio", @@ -42,4 +30,4 @@ } }, "skills": {} -} \ No newline at end of file +} diff --git a/frontend/src/content/en/harness/mcp.mdx b/frontend/src/content/en/harness/mcp.mdx index 0e43aa235..53e5ca274 100644 --- a/frontend/src/content/en/harness/mcp.mdx +++ b/frontend/src/content/en/harness/mcp.mdx @@ -29,11 +29,6 @@ The default location is the project root (same directory as `config.yaml`). The "args": ["-y", "@my-org/my-mcp-server"], "enabled": true }, - "filesystem": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"], - "enabled": true - }, "sqlite": { "command": "uvx", "args": ["mcp-server-sqlite", "--db-path", "/path/to/db.sqlite"], @@ -43,6 +38,16 @@ The default location is the project root (same directory as `config.yaml`). The } ``` + + Do not add an MCP filesystem server for DeerFlow workspace files. DeerFlow + already provides built-in file tools for thread-scoped workspace access, and + overlapping file tools with different path semantics can make LLM tool + selection and file access behavior unstable. DeerFlow does not currently + adapt MCP Roots mode for filesystem servers: it does not publish per-thread + MCP roots or map sandbox paths such as /mnt/user-data/... to + paths accepted by @modelcontextprotocol/server-filesystem. + + Each server entry supports: - `command`: the executable to run (e.g., `npx`, `uvx`, `python`) diff --git a/frontend/src/content/zh/application/configuration.mdx b/frontend/src/content/zh/application/configuration.mdx index 0094323e7..94e78120c 100644 --- a/frontend/src/content/zh/application/configuration.mdx +++ b/frontend/src/content/zh/application/configuration.mdx @@ -193,15 +193,23 @@ BETTER_AUTH_SECRET=local-dev-secret-at-least-32-chars ```json { "mcpServers": { - "filesystem": { + "my-server": { "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"], + "args": ["-y", "@my-org/my-mcp-server"], "enabled": true } } } ``` + + 不要为 DeerFlow 工作区文件引入 MCP filesystem server。它会与 DeerFlow + 内置文件工具形成路径语义不同的重复能力,让 LLM 行为不稳定。DeerFlow + 当前没有为 filesystem server 适配 MCP Roots 模式,也不会把{" "} + /mnt/user-data/... 这类沙箱路径映射成{" "} + @modelcontextprotocol/server-filesystem 可接受的路径。 + + ### 技能启用状态 技能启用状态会反映在 `extensions_config.json` 中。你可以直接编辑它,或通过 DeerFlow 应用界面进行管理。 diff --git a/frontend/src/content/zh/harness/mcp.mdx b/frontend/src/content/zh/harness/mcp.mdx index 4bf72c1a6..0b076aff8 100644 --- a/frontend/src/content/zh/harness/mcp.mdx +++ b/frontend/src/content/zh/harness/mcp.mdx @@ -28,11 +28,6 @@ MCP 服务器在 `extensions_config.json` 中配置,这个文件独立于 `con "args": ["-y", "@my-org/my-mcp-server"], "enabled": true }, - "filesystem": { - "command": "npx", - "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/dir"], - "enabled": true - }, "sqlite": { "command": "uvx", "args": ["mcp-server-sqlite", "--db-path", "/path/to/db.sqlite"], @@ -42,6 +37,15 @@ MCP 服务器在 `extensions_config.json` 中配置,这个文件独立于 `con } ``` + + 不要为 DeerFlow 工作区文件引入 MCP filesystem server。DeerFlow 已提供按 + thread 隔离的内置文件工具;重复引入路径语义不同的文件工具,会让 LLM + 的工具选择和文件访问行为不稳定。DeerFlow 当前没有为 filesystem server + 适配 MCP Roots 模式:不会发布按 thread 收窄的 MCP roots,也不会把{" "} + /mnt/user-data/... 这类沙箱路径映射成{" "} + @modelcontextprotocol/server-filesystem 可接受的路径。 + + 每个服务器条目支持: - `command`:要运行的可执行文件(如 `npx`、`uvx`、`python`) From be0eae9825619b63ca0c67b253d40b5eee76a2d6 Mon Sep 17 00:00:00 2001 From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com> Date: Fri, 22 May 2026 21:20:28 +0800 Subject: [PATCH 67/86] fix(runtime): suppress tool execution when provider safety-terminates with tool_calls (#3035) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(runtime): suppress tool execution when provider safety-terminates with tool_calls When a provider stops generation for safety reasons (OpenAI/Moonshot finish_reason=content_filter, Anthropic stop_reason=refusal, Gemini finish_reason=SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/RECITATION/ IMAGE_SAFETY/...), the response may still carry truncated tool_calls. LangChain's tool router treats any non-empty tool_calls as executable, so partial arguments (e.g. write_file with a half-finished markdown) get dispatched and the agent loops on retry. Add SafetyFinishReasonMiddleware at after_model: detect safety termination via a pluggable detector registry, clear both structured tool_calls and raw additional_kwargs.tool_calls / function_call, preserve response_metadata.finish_reason for downstream observers, stamp additional_kwargs.safety_termination for traces, append a user-facing explanation to message content (list-aware for thinking blocks), and emit a safety_termination custom stream event so SSE consumers can reconcile any "tool starting..." UI. Default detectors cover OpenAI-compatible content_filter, Anthropic refusal, and Gemini safety enums (text + image). Custom providers are added via reflection (same pattern as guardrails). Wired into both lead-agent and subagent runtimes. Closes #3028 * fix(runtime): persist safety_termination as a middleware audit event Address review on #3035: the SSE custom event is great for live consumers but invisible to post-run audit. RunEventStore should carry its own row so operators can answer "which runs were safety-suppressed today?" from a single SQL query without joining the message body. Worker now exposes the run-scoped RunJournal via runtime.context["__run_journal"] (sentinel key, internal channel). SafetyFinishReasonMiddleware calls the previously-unused RunJournal.record_middleware, which emits event_type = "middleware:safety_termination" category = "middleware" content = {name, hook, action, changes={ detector, reason_field, reason_value, suppressed_tool_call_count, suppressed_tool_call_names, suppressed_tool_call_ids, message_id, extras}} Tool *arguments* are deliberately excluded — those are the very content the provider filtered and persisting them would defeat the purpose of the safety filter (per review note in #3035). Graceful skips when journal is absent (subagent runtime, unit tests, no-event-store local dev). Journal exceptions never propagate into the agent loop. Refs #3028 * fix(runtime): satisfy ruff format + address Copilot review - ruff format on safety_finish_reason_config.py and e2e demo (CI lint failed on ruff format --check; backend Makefile lint target runs ruff check AND ruff format --check). - Docstring on SafetyFinishReasonConfig now says resolve_variable to match the actual loader used in from_config (the wording was resolve_class previously; behavior is unchanged — resolve_variable mirrors how guardrails.provider is loaded). - Switch the AIMessage type check in SafetyFinishReasonMiddleware._apply from getattr(last, "type") == "ai" to isinstance(last, AIMessage), matching TokenUsageMiddleware / TodoMiddleware / ViewImageMiddleware / SummarizationMiddleware which are the dominant pattern. Refs #3028 --- .../deerflow/agents/lead_agent/agent.py | 10 + .../safety_finish_reason_middleware.py | 317 +++++++++ .../safety_termination_detectors.py | 237 +++++++ .../tool_error_handling_middleware.py | 10 + .../harness/deerflow/config/app_config.py | 2 + .../config/safety_finish_reason_config.py | 47 ++ .../harness/deerflow/runtime/runs/worker.py | 6 + .../scripts/e2e_safety_termination_demo.py | 206 ++++++ .../tests/test_lead_agent_model_resolution.py | 7 +- ..._safety_finish_reason_graph_integration.py | 225 ++++++ .../test_safety_finish_reason_middleware.py | 651 ++++++++++++++++++ .../test_safety_termination_detectors.py | 176 +++++ .../test_tool_error_handling_middleware.py | 10 +- config.example.yaml | 37 +- 14 files changed, 1936 insertions(+), 5 deletions(-) create mode 100644 backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py create mode 100644 backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py create mode 100644 backend/packages/harness/deerflow/config/safety_finish_reason_config.py create mode 100644 backend/scripts/e2e_safety_termination_demo.py create mode 100644 backend/tests/test_safety_finish_reason_graph_integration.py create mode 100644 backend/tests/test_safety_finish_reason_middleware.py create mode 100644 backend/tests/test_safety_termination_detectors.py diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index 328a8a6e1..e03ff33ad 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -29,6 +29,7 @@ from deerflow.agents.memory.summarization_hook import memory_flush_hook from deerflow.agents.middlewares.clarification_middleware import ClarificationMiddleware from deerflow.agents.middlewares.loop_detection_middleware import LoopDetectionMiddleware from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware +from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware from deerflow.agents.middlewares.subagent_limit_middleware import SubagentLimitMiddleware from deerflow.agents.middlewares.summarization_middleware import BeforeSummarizationHook, DeerFlowSummarizationMiddleware from deerflow.agents.middlewares.title_middleware import TitleMiddleware @@ -338,6 +339,15 @@ def _build_middlewares( if custom_middlewares: middlewares.extend(custom_middlewares) + # SafetyFinishReasonMiddleware — suppress tool execution when the provider + # safety-terminated the response. Registered after custom middlewares so + # that LangChain's reverse-order after_model dispatch runs Safety first; + # cleared tool_calls then flow through Loop/Subagent accounting without + # firing extra alarms. See safety_finish_reason_middleware.py docstring. + safety_config = resolved_app_config.safety_finish_reason + if safety_config.enabled: + middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config)) + # ClarificationMiddleware should always be last middlewares.append(ClarificationMiddleware()) return middlewares diff --git a/backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py new file mode 100644 index 000000000..8fd733c23 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/safety_finish_reason_middleware.py @@ -0,0 +1,317 @@ +"""Suppress tool execution when the provider safety-terminated the response. + +Background — see issue bytedance/deer-flow#3028. + +Some providers (OpenAI ``finish_reason='content_filter'``, Anthropic +``stop_reason='refusal'``, Gemini ``finish_reason='SAFETY'`` ...) can stop +generation mid-stream while still returning partially-formed ``tool_calls``. +LangChain's tool router treats any AIMessage with a non-empty ``tool_calls`` +field as "go execute these", so half-truncated arguments — e.g. a markdown +``write_file`` that stops in the middle of a sentence — get dispatched as if +they were complete. The agent then sees the truncated file, tries to fix it, +gets filtered again, and loops. + +This middleware sits at ``after_model`` and gates that behaviour: when a +configured ``SafetyTerminationDetector`` fires *and* the AIMessage carries +tool calls, we strip the tool calls (both structured and raw provider +payloads), append a user-facing explanation, and stash observability fields +in ``additional_kwargs.safety_termination`` so logs, traces, and SSE +consumers can see what happened. + +Hook choice: ``after_model`` (not ``wrap_model_call``) because the response +is a *normal* return — not an exception — and we want to participate in the +same after-model chain as ``LoopDetectionMiddleware``, with which we share +the same tool-call-suppression mechanic but a different trigger. + +Placement: register *after* ``LoopDetectionMiddleware`` in the middleware +list. LangChain factory wires ``after_model`` edges in reverse list order +(``langchain/agents/factory.py:add_edge("model", middleware_w_after_model[-1])``, +then walks ``range(len-1, 0, -1)``), so the *last* registered middleware is +the *first* to observe the model output. Registering Safety after Loop +means Safety sees the raw response first, clears tool calls if it fires, +and Loop then accounts against the cleaned message. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, override + +from langchain.agents import AgentState +from langchain.agents.middleware import AgentMiddleware +from langchain_core.messages import AIMessage +from langgraph.runtime import Runtime + +from deerflow.agents.middlewares.safety_termination_detectors import ( + SafetyTermination, + SafetyTerminationDetector, + default_detectors, +) +from deerflow.agents.middlewares.tool_call_metadata import clone_ai_message_with_tool_calls + +if TYPE_CHECKING: + from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig + +logger = logging.getLogger(__name__) + + +_USER_FACING_MESSAGE = ( + "The model provider stopped this response with a safety-related signal " + "({reason_field}={reason_value!r}, detector={detector!r}). Any tool " + "calls produced in this turn were suppressed because their arguments " + "may be truncated and unsafe to execute. Please rephrase the request " + "or ask for a narrower output." +) + + +class SafetyFinishReasonMiddleware(AgentMiddleware[AgentState]): + """Strip tool_calls from AIMessages flagged by a SafetyTerminationDetector.""" + + def __init__(self, detectors: list[SafetyTerminationDetector] | None = None) -> None: + super().__init__() + # Copy so caller mutations after construction don't leak into us. + self._detectors: list[SafetyTerminationDetector] = list(detectors) if detectors else default_detectors() + + @classmethod + def from_config(cls, config: SafetyFinishReasonConfig) -> SafetyFinishReasonMiddleware: + """Construct from validated Pydantic config, honouring the + reflection-loaded detector list when provided. + + An explicit empty list is intentionally rejected — it would silently + disable detection while leaving the middleware in the chain, which + is the worst of both worlds. Use ``enabled: false`` instead. + """ + if config.detectors is None: + return cls() + + if not config.detectors: + raise ValueError("safety_finish_reason.detectors must be omitted (use built-ins) or contain at least one entry; use enabled=false to disable the middleware entirely.") + + from deerflow.reflection import resolve_variable + + detectors: list[SafetyTerminationDetector] = [] + for entry in config.detectors: + detector_cls = resolve_variable(entry.use) + kwargs = dict(entry.config) if entry.config else {} + detector = detector_cls(**kwargs) + if not isinstance(detector, SafetyTerminationDetector): + raise TypeError(f"{entry.use} did not produce a SafetyTerminationDetector (got {type(detector).__name__}); ensure it has a `name` attribute and a `detect(message)` method") + detectors.append(detector) + return cls(detectors=detectors) + + # ----- detection ------------------------------------------------------- + + def _detect(self, message: AIMessage) -> SafetyTermination | None: + for detector in self._detectors: + try: + hit = detector.detect(message) + except Exception: # noqa: BLE001 - never let a buggy detector break the agent run + logger.exception("SafetyTerminationDetector %r raised; treating as no-match", getattr(detector, "name", type(detector).__name__)) + continue + if hit is not None: + return hit + return None + + # ----- message rewriting ---------------------------------------------- + + @staticmethod + def _append_user_message(content: object, text: str) -> str | list: + """Append a plain-text explanation to AIMessage content. + + Mirrors ``LoopDetectionMiddleware._append_text`` so list-content + responses (Anthropic thinking blocks, vLLM reasoning splits) keep + their structure instead of being string-coerced into a TypeError. + """ + if content is None or content == "": + return text + if isinstance(content, list): + return [*content, {"type": "text", "text": f"\n\n{text}"}] + if isinstance(content, str): + return content + f"\n\n{text}" + return str(content) + f"\n\n{text}" + + def _build_suppressed_message( + self, + message: AIMessage, + termination: SafetyTermination, + ) -> AIMessage: + suppressed_names = [tc.get("name") or "unknown" for tc in (message.tool_calls or [])] + explanation = _USER_FACING_MESSAGE.format( + reason_field=termination.reason_field, + reason_value=termination.reason_value, + detector=termination.detector, + ) + new_content = self._append_user_message(message.content, explanation) + + # clone_ai_message_with_tool_calls handles structured tool_calls, + # raw additional_kwargs.tool_calls, and function_call in one shot. + # It only rewrites finish_reason when the old value was "tool_calls", + # which is not our case — content_filter / refusal / SAFETY stay put + # so downstream SSE / converters keep seeing the real provider reason. + cleared = clone_ai_message_with_tool_calls(message, [], content=new_content) + + # Re-clone additional_kwargs so we don't accidentally mutate the + # dict returned by clone_ai_message_with_tool_calls (which already + # made a shallow copy, but downstream model_copy still references + # it). Then stamp the observability record. + kwargs = dict(getattr(cleared, "additional_kwargs", None) or {}) + kwargs["safety_termination"] = { + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_count": len(suppressed_names), + "suppressed_tool_call_names": suppressed_names, + "extras": dict(termination.extras) if termination.extras else {}, + } + return cleared.model_copy(update={"additional_kwargs": kwargs}) + + # ----- observability --------------------------------------------------- + + def _emit_event( + self, + termination: SafetyTermination, + suppressed_names: list[str], + runtime: Runtime, + ) -> None: + """Notify SSE consumers (e.g. the web UI) that a tool turn was + suppressed so they can reconcile any "tool starting..." placeholders + already streamed to the user. Failures are logged at debug and + ignored — this is a best-effort signal.""" + try: + from langgraph.config import get_stream_writer + + writer = get_stream_writer() + except Exception: # noqa: BLE001 + logger.debug("get_stream_writer unavailable; skipping safety_termination event", exc_info=True) + return + + thread_id = None + if runtime is not None and getattr(runtime, "context", None): + thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None + + try: + writer( + { + "type": "safety_termination", + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_count": len(suppressed_names), + "suppressed_tool_call_names": suppressed_names, + "thread_id": thread_id, + } + ) + except Exception: # noqa: BLE001 + logger.debug("Failed to emit safety_termination stream event", exc_info=True) + + def _record_audit_event( + self, + termination: SafetyTermination, + message, + tool_calls: list[dict], + runtime: Runtime, + ) -> None: + """Write a ``middleware:safety_termination`` record to RunEventStore + for post-run auditability. + + The custom stream event in ``_emit_event`` is consumed by live SSE + clients and disappears after the run; this event is persisted so an + operator can answer "which runs were safety-suppressed today?" from + a single SQL query without joining the message body. Worker exposes + the run-scoped ``RunJournal`` via ``runtime.context["__run_journal"]``; + absent in unit-test / subagent / no-event-store paths, in which case + we silently skip. + + Tool **arguments** are deliberately **not** recorded — those are the + very content the provider filtered; persisting them would defeat the + purpose of the safety filter. Names / count / ids are sufficient for + audit and debugging (issue #3028 review). + """ + journal = None + if runtime is not None and getattr(runtime, "context", None): + context = runtime.context + if isinstance(context, dict): + journal = context.get("__run_journal") + if journal is None: + return + + suppressed_names = [tc.get("name") or "unknown" for tc in tool_calls] + suppressed_ids = [tc.get("id") for tc in tool_calls if tc.get("id")] + + changes = { + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_count": len(tool_calls), + "suppressed_tool_call_names": suppressed_names, + "suppressed_tool_call_ids": suppressed_ids, + "message_id": getattr(message, "id", None), + "extras": dict(termination.extras) if termination.extras else {}, + } + + try: + journal.record_middleware( + tag="safety_termination", + name=type(self).__name__, + hook="after_model", + action="suppress_tool_calls", + changes=changes, + ) + except Exception: # noqa: BLE001 + # Audit-event persistence must never break agent execution. + logger.debug("Failed to record middleware:safety_termination event", exc_info=True) + + # ----- main apply ------------------------------------------------------ + + def _apply(self, state: AgentState, runtime: Runtime) -> dict | None: + messages = state.get("messages", []) + if not messages: + return None + + last = messages[-1] + if not isinstance(last, AIMessage): + return None + + # Issue scope: only intervene when there's something to suppress. + # ``content_filter`` without tool_calls is allowed through unchanged + # so the partial text response (if any) reaches the user naturally. + tool_calls = last.tool_calls + if not tool_calls: + return None + + termination = self._detect(last) + if termination is None: + return None + + patched = self._build_suppressed_message(last, termination) + + thread_id = None + if runtime is not None and getattr(runtime, "context", None): + thread_id = runtime.context.get("thread_id") if isinstance(runtime.context, dict) else None + + logger.warning( + "Provider safety termination detected — suppressed %d tool call(s)", + len(tool_calls), + extra={ + "thread_id": thread_id, + "detector": termination.detector, + "reason_field": termination.reason_field, + "reason_value": termination.reason_value, + "suppressed_tool_call_names": [tc.get("name") for tc in tool_calls], + }, + ) + + self._emit_event(termination, [tc.get("name") or "unknown" for tc in tool_calls], runtime) + self._record_audit_event(termination, last, list(tool_calls), runtime) + + return {"messages": [patched]} + + # ----- hooks ----------------------------------------------------------- + + @override + def after_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return self._apply(state, runtime) + + @override + async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None: + return self._apply(state, runtime) diff --git a/backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py b/backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py new file mode 100644 index 000000000..b98e9f4d7 --- /dev/null +++ b/backend/packages/harness/deerflow/agents/middlewares/safety_termination_detectors.py @@ -0,0 +1,237 @@ +"""Detectors for provider-side safety termination signals. + +Different LLM providers signal "I stopped this response for safety reasons" +through different fields with different values. This module defines a small +strategy interface and three built-in detectors that cover the major +providers DeerFlow supports today. New providers (Wenxin, Hunyuan, Bedrock +adapters, in-house gateways, ...) can be added by implementing +``SafetyTerminationDetector`` and wiring it through +``config.yaml: safety_finish_reason.detectors``. + +The middleware that consumes these detectors lives in +``safety_finish_reason_middleware.py``. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Protocol, runtime_checkable + +from langchain_core.messages import AIMessage + + +@dataclass(frozen=True) +class SafetyTermination: + """A detected safety-related termination signal. + + Attributes: + detector: Name of the detector that produced this result. Used for + observability so operators can see which provider rule fired. + reason_field: The message metadata field that carried the signal + (e.g. ``finish_reason``, ``stop_reason``). + reason_value: The actual value of that field + (e.g. ``content_filter``, ``refusal``, ``SAFETY``). + extras: Provider-specific metadata that may help downstream + consumers (e.g. Azure OpenAI content_filter_results, Gemini + safety_ratings). Detectors are free to populate or skip this. + """ + + detector: str + reason_field: str + reason_value: str + extras: dict[str, Any] = field(default_factory=dict) + + +@runtime_checkable +class SafetyTerminationDetector(Protocol): + """Strategy interface for provider safety termination detection.""" + + name: str + + def detect(self, message: AIMessage) -> SafetyTermination | None: + """Return a SafetyTermination if *message* indicates provider safety + termination, otherwise return ``None``. + + Implementations must be side-effect free and tolerant of missing or + oddly-typed metadata — detectors run on every model response. + """ + ... + + +def _get_metadata_value(message: AIMessage, field_name: str) -> str | None: + """Read a string-typed value from either ``response_metadata`` or + ``additional_kwargs``. + + LangChain provider adapters are inconsistent about where they stash + provider stop signals. Most modern adapters use ``response_metadata``, + but some legacy / passthrough paths still surface them via + ``additional_kwargs``. We check both, in that order, and only accept + string values — Pydantic enums or dicts are ignored so we never raise + on malformed inputs. + """ + for container_name in ("response_metadata", "additional_kwargs"): + container = getattr(message, container_name, None) or {} + if not isinstance(container, dict): + continue + value = container.get(field_name) + if isinstance(value, str) and value: + return value + return None + + +class OpenAICompatibleContentFilterDetector: + """OpenAI-compatible content_filter signal. + + Covers OpenAI, Azure OpenAI, Moonshot/Kimi, DeepSeek, Mistral, vLLM, + Qwen (OpenAI-compatible mode), and any other adapter that follows the + OpenAI ``finish_reason`` convention. + + Some Chinese providers ship custom OpenAI-compatible gateways that use + alternative tokens like ``sensitive`` or ``violation``. Extend the set + via the ``finish_reasons`` kwarg in config. + """ + + name = "openai_compatible_content_filter" + + def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None: + configured = finish_reasons if finish_reasons is not None else ("content_filter",) + self._finish_reasons: frozenset[str] = frozenset(r.lower() for r in configured) + + def detect(self, message: AIMessage) -> SafetyTermination | None: + value = _get_metadata_value(message, "finish_reason") + if value is None or value.lower() not in self._finish_reasons: + return None + + extras: dict[str, Any] = {} + # Azure OpenAI ships a structured content_filter_results block; carry it + # through so operators can see *what* was filtered without re-tracing. + response_metadata = getattr(message, "response_metadata", None) or {} + if isinstance(response_metadata, dict): + filter_results = response_metadata.get("content_filter_results") + if filter_results: + extras["content_filter_results"] = filter_results + + return SafetyTermination( + detector=self.name, + reason_field="finish_reason", + reason_value=value, + extras=extras, + ) + + +class AnthropicRefusalDetector: + """Anthropic ``stop_reason == "refusal"`` signal. + + Anthropic models surface safety refusals via a dedicated ``stop_reason`` + rather than ``finish_reason``. See: + https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals + """ + + name = "anthropic_refusal" + + def __init__(self, stop_reasons: list[str] | tuple[str, ...] | None = None) -> None: + configured = stop_reasons if stop_reasons is not None else ("refusal",) + self._stop_reasons: frozenset[str] = frozenset(r.lower() for r in configured) + + def detect(self, message: AIMessage) -> SafetyTermination | None: + value = _get_metadata_value(message, "stop_reason") + if value is None or value.lower() not in self._stop_reasons: + return None + return SafetyTermination( + detector=self.name, + reason_field="stop_reason", + reason_value=value, + ) + + +class GeminiSafetyDetector: + """Gemini / Vertex AI safety-related finish reasons. + + Gemini uses the same ``finish_reason`` field as OpenAI but with an + enumerated upper-case taxonomy. The default set covers every Gemini + finish_reason that means "the model stopped because the content/image + tripped a safety, blocklist, recitation, or PII filter" — i.e. cases + where any tool_calls returned alongside are likely truncated/ + unreliable. Full enum: + https://docs.cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform_v1.types.Candidate.FinishReason + + Intentionally **excluded** from the default set: + - ``STOP`` — normal termination. + - ``MAX_TOKENS`` — output length truncation, not safety + (same root failure mode as + content_filter, but issue #3028 + scopes it out; expose separately if + desired). + - ``LANGUAGE`` / ``NO_IMAGE`` — capability mismatches, unrelated to + safety; tool_calls would be absent + anyway. + - ``MALFORMED_FUNCTION_CALL`` / + ``UNEXPECTED_TOOL_CALL`` — tool-call protocol errors. The + tool_calls are *also* unreliable + here, but the failure category is + distinct from safety filtering; + handle in a dedicated detector to + keep observability records honest. + - ``OTHER`` / ``IMAGE_OTHER`` / + ``FINISH_REASON_UNSPECIFIED`` — too broad to enable by default; + opt in via ``finish_reasons=`` if + your provider abuses these. + """ + + name = "gemini_safety" + + _DEFAULT_FINISH_REASONS = ( + # Text safety + "SAFETY", + "BLOCKLIST", + "PROHIBITED_CONTENT", + "SPII", + "RECITATION", + # Image safety (multimodal generation) + "IMAGE_SAFETY", + "IMAGE_PROHIBITED_CONTENT", + "IMAGE_RECITATION", + ) + + def __init__(self, finish_reasons: list[str] | tuple[str, ...] | None = None) -> None: + configured = finish_reasons if finish_reasons is not None else self._DEFAULT_FINISH_REASONS + self._finish_reasons: frozenset[str] = frozenset(r.upper() for r in configured) + + def detect(self, message: AIMessage) -> SafetyTermination | None: + value = _get_metadata_value(message, "finish_reason") + if value is None or value.upper() not in self._finish_reasons: + return None + + extras: dict[str, Any] = {} + response_metadata = getattr(message, "response_metadata", None) or {} + if isinstance(response_metadata, dict): + # Gemini surfaces per-category scoring under safety_ratings. + ratings = response_metadata.get("safety_ratings") + if ratings: + extras["safety_ratings"] = ratings + + return SafetyTermination( + detector=self.name, + reason_field="finish_reason", + reason_value=value, + extras=extras, + ) + + +def default_detectors() -> list[SafetyTerminationDetector]: + """Built-in detector set used when no custom detectors are configured.""" + return [ + OpenAICompatibleContentFilterDetector(), + AnthropicRefusalDetector(), + GeminiSafetyDetector(), + ] + + +__all__ = [ + "AnthropicRefusalDetector", + "GeminiSafetyDetector", + "OpenAICompatibleContentFilterDetector", + "SafetyTermination", + "SafetyTerminationDetector", + "default_detectors", +] diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py index 4393bd360..ae3522454 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py @@ -164,4 +164,14 @@ def build_subagent_runtime_middlewares( middlewares.append(ViewImageMiddleware()) + # Same provider safety-termination guard the lead agent uses — subagents + # are equally exposed to truncated tool_calls returned with + # finish_reason=content_filter (and friends), and the bad call would then + # propagate back to the lead agent via the task tool result. + safety_config = app_config.safety_finish_reason + if safety_config.enabled: + from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware + + middlewares.append(SafetyFinishReasonMiddleware.from_config(safety_config)) + return middlewares diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index d470d6558..931c95757 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -20,6 +20,7 @@ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_ from deerflow.config.model_config import ModelConfig from deerflow.config.run_events_config import RunEventsConfig from deerflow.config.runtime_paths import existing_project_file +from deerflow.config.safety_finish_reason_config import SafetyFinishReasonConfig from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.skill_evolution_config import SkillEvolutionConfig from deerflow.config.skills_config import SkillsConfig @@ -102,6 +103,7 @@ class AppConfig(BaseModel): guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration") + safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration") model_config = ConfigDict(extra="allow") database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration") run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration") diff --git a/backend/packages/harness/deerflow/config/safety_finish_reason_config.py b/backend/packages/harness/deerflow/config/safety_finish_reason_config.py new file mode 100644 index 000000000..0e8adebc5 --- /dev/null +++ b/backend/packages/harness/deerflow/config/safety_finish_reason_config.py @@ -0,0 +1,47 @@ +"""Configuration for SafetyFinishReasonMiddleware. + +Mirrors the shape of GuardrailsConfig: detectors are loaded by class path +through ``deerflow.reflection.resolve_variable`` (same loader the +``guardrails.provider`` config uses) so users can drop in custom provider +detectors without modifying core code. +""" + +from __future__ import annotations + +from pydantic import BaseModel, Field + + +class SafetyDetectorConfig(BaseModel): + """One detector entry under ``safety_finish_reason.detectors``.""" + + use: str = Field( + description=("Class path of a SafetyTerminationDetector implementation (e.g. 'deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector')."), + ) + config: dict = Field( + default_factory=dict, + description="Constructor kwargs passed to the detector class.", + ) + + +class SafetyFinishReasonConfig(BaseModel): + """Configuration for the SafetyFinishReasonMiddleware. + + The middleware intercepts AIMessages where the provider signaled a + safety-related termination (e.g. OpenAI ``finish_reason='content_filter'``) + while still returning tool calls, and suppresses those tool calls so the + half-truncated arguments never execute. + """ + + enabled: bool = Field( + default=True, + description="Master switch for the SafetyFinishReasonMiddleware.", + ) + detectors: list[SafetyDetectorConfig] | None = Field( + default=None, + description=( + "Custom detector list. Leave unset (None) to use the built-in " + "set covering OpenAI-compatible content_filter, Anthropic " + "refusal, and Gemini SAFETY/BLOCKLIST/PROHIBITED_CONTENT/SPII/" + "RECITATION. Provide a non-null list to fully override." + ), + ) diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index aa47cd39b..694464fe3 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -219,6 +219,12 @@ async def run_agent( # manually here because we drive the graph through ``agent.astream(config=...)`` # without passing the official ``context=`` parameter. runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config) + # Expose the run-scoped journal under a sentinel key so middleware can + # write audit events (e.g. SafetyFinishReasonMiddleware recording + # suppressed tool calls). Double-underscore prefix marks it as a + # runtime-internal channel; user code must not depend on the key name. + if journal is not None: + runtime_ctx["__run_journal"] = journal _install_runtime_context(config, runtime_ctx) runtime = Runtime(context=cast(Any, runtime_ctx), store=store) config.setdefault("configurable", {})["__pregel_runtime"] = runtime diff --git a/backend/scripts/e2e_safety_termination_demo.py b/backend/scripts/e2e_safety_termination_demo.py new file mode 100644 index 000000000..7fd27b23f --- /dev/null +++ b/backend/scripts/e2e_safety_termination_demo.py @@ -0,0 +1,206 @@ +"""End-to-end demo: SafetyFinishReasonMiddleware on the real DeerFlow lead-agent. + +What it proves +-------------- +- The real ``make_lead_agent`` / ``DeerFlowClient`` pipeline is built (full + 18-middleware chain, sandbox, tools, etc.). +- A model that returns ``finish_reason='content_filter'`` + ``tool_calls`` + triggers SafetyFinishReasonMiddleware. +- LangChain's tool router never invokes ``write_file`` — the truncated + arguments do **not** reach the sandbox. +- A ``safety_termination`` custom event is emitted on the stream and the + final AIMessage carries the observability stamp. + +Run from backend/ directory: + PYTHONPATH=. uv run python scripts/e2e_safety_termination_demo.py +""" + +from __future__ import annotations + +import sys +from typing import Any + +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, ChatResult + +# --------------------------------------------------------------------------- +# Fake provider that mimics Moonshot's content_filter behaviour +# --------------------------------------------------------------------------- + + +class _ContentFilteredFakeModel(BaseChatModel): + """First call returns finish_reason=content_filter + truncated write_file + tool_call. Subsequent calls return a normal stop response so the agent + can terminate (the middleware should make a second call unnecessary by + clearing tool_calls, but we keep this safety net in case loop-detection + or anything else triggers another model invocation).""" + + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-content-filtered" + + def bind_tools(self, tools, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self.call_count += 1 + if self.call_count == 1: + msg = AIMessage( + content="# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与", + tool_calls=[ + { + "id": "call_truncated_write", + "name": "write_file", + "args": { + "path": "/mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md", + "content": "# 政经周报\n- **会晤时间**:2026年5月12日—13日,特朗普访问中国,与", + }, + } + ], + response_metadata={ + "finish_reason": "content_filter", + "model_name": "kimi-k2.6", + "model_provider": "openai", + }, + ) + else: + msg = AIMessage( + content="(secondary call, should not be needed)", + response_metadata={"finish_reason": "stop", "model_name": "kimi-k2.6"}, + ) + return ChatResult(generations=[ChatGeneration(message=msg)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + +# --------------------------------------------------------------------------- +# Driver +# --------------------------------------------------------------------------- + + +def main() -> int: + # Inject the fake model BEFORE constructing the client. Both the + # client module and the lead-agent module bind ``create_chat_model`` + # at import time via ``from deerflow.models import create_chat_model``, + # so we patch both attribute slots — the source-of-truth patch on + # ``factory.create_chat_model`` doesn't propagate back into already- + # imported names. + import deerflow.agents.lead_agent.agent as lead_agent_module + import deerflow.client as client_module + + fake = _ContentFilteredFakeModel() + originals = { + "lead": lead_agent_module.create_chat_model, + "client": client_module.create_chat_model, + } + + def fake_create_chat_model(*args, **kwargs): + return fake + + lead_agent_module.create_chat_model = fake_create_chat_model + client_module.create_chat_model = fake_create_chat_model + + from deerflow.client import DeerFlowClient + + try: + client = DeerFlowClient() + + print("\n=== Streaming a turn through the real lead-agent ===") + events: list[dict[str, Any]] = [] + for event in client.stream( + "帮我整理一下最近一周政经新闻,写到 /mnt/user-data/outputs/political-economic-news-weekly-may-16-2026.md", + thread_id="e2e-safety-1", + ): + events.append({"type": event.type, "data": event.data}) + + # ---- Assertions ---- + safety_event = next( + (e for e in events if e["type"] == "custom" and isinstance(e["data"], dict) and e["data"].get("type") == "safety_termination"), + None, + ) + final_values = next( + (e for e in reversed(events) if e["type"] == "values"), + None, + ) + tool_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "tool"] + ai_tool_call_messages = [e for e in events if e["type"] == "messages-tuple" and isinstance(e["data"], dict) and e["data"].get("type") == "ai" and e["data"].get("tool_calls")] + + print(f"\n[stats] total stream events: {len(events)}") + print(f"[stats] model call count: {fake.call_count}") + print(f"[stats] tool messages on stream: {len(tool_messages)}") + print(f"[stats] AI messages carrying tool_calls: {len(ai_tool_call_messages)}") + + print("\n[event] safety_termination custom event:") + if safety_event is None: + print(" *** NOT FOUND ***") + return 1 + for k, v in safety_event["data"].items(): + print(f" {k}: {v}") + + print("\n[state] final AIMessage from last values snapshot:") + if final_values is None: + print(" *** no values snapshot ***") + return 1 + # `values` event carries `_serialize_message` dicts, not Message objects. + final_messages = final_values["data"].get("messages") or [] + last_ai = next((m for m in reversed(final_messages) if isinstance(m, dict) and m.get("type") == "ai"), None) + if last_ai is None: + print(" *** no AIMessage in final state ***") + print(f" message types seen: {[m.get('type') if isinstance(m, dict) else type(m).__name__ for m in final_messages]}") + return 1 + + tool_calls = last_ai.get("tool_calls") or [] + additional_kwargs = last_ai.get("additional_kwargs") or {} + response_metadata = last_ai.get("response_metadata") or {} + content = last_ai.get("content") + + print(f" tool_calls (must be empty): {tool_calls}") + print(f" additional_kwargs.safety_termination: {additional_kwargs.get('safety_termination')}") + content_preview = (content if isinstance(content, str) else str(content))[:200] + print(f" content[:200]: {content_preview!r}") + print(f" response_metadata.finish_reason: {response_metadata.get('finish_reason')}") + + # NOTE: `client._serialize_message` does not include `response_metadata` + # in the values-event payload (client-layer behaviour, unrelated to the + # middleware). The middleware *does* preserve finish_reason on the + # AIMessage object — see test_safety_finish_reason_middleware.py:: + # TestMessageRewrite::test_preserves_response_metadata_finish_reason. + # Here we assert on the observability stamp, which carries the same + # evidence and is in the serialized payload. + stamp = additional_kwargs.get("safety_termination") or {} + failures = [] + if tool_calls: + failures.append("final AIMessage still has tool_calls — middleware did NOT clear them") + if not stamp: + failures.append("final AIMessage missing safety_termination observability stamp") + if tool_messages: + failures.append(f"tool node was invoked: {len(tool_messages)} ToolMessage(s) on stream") + if stamp.get("reason_value") != "content_filter": + failures.append(f"safety_termination.reason_value was {stamp.get('reason_value')!r}, expected 'content_filter'") + if safety_event is None: + failures.append("safety_termination custom event was not emitted on the stream") + + if failures: + print("\n=== FAIL ===") + for f in failures: + print(f" - {f}") + return 1 + + print("\n=== PASS ===") + print(" - tool_calls cleared on final AIMessage") + print(" - tool node never invoked (no ToolMessage on stream)") + print(" - safety_termination custom event emitted") + print(" - observability stamp written to additional_kwargs") + print(" - response_metadata.finish_reason preserved for downstream SSE") + return 0 + finally: + lead_agent_module.create_chat_model = originals["lead"] + client_module.create_chat_model = originals["client"] + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index 7ac4b97e6..a12a754c2 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -336,8 +336,11 @@ def test_build_middlewares_uses_resolved_model_name_for_vision(monkeypatch): ) assert any(isinstance(m, lead_agent_module.ViewImageMiddleware) for m in middlewares) - # verify the custom middleware is injected correctly - assert len(middlewares) > 0 and isinstance(middlewares[-2], MagicMock) + # verify the custom middleware is injected correctly. + # Chain tail order after the custom middleware is: + # ..., custom, SafetyFinishReasonMiddleware, ClarificationMiddleware + # so the custom mock sits at index [-3]. + assert len(middlewares) > 0 and isinstance(middlewares[-3], MagicMock) def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch): diff --git a/backend/tests/test_safety_finish_reason_graph_integration.py b/backend/tests/test_safety_finish_reason_graph_integration.py new file mode 100644 index 000000000..f26a7be90 --- /dev/null +++ b/backend/tests/test_safety_finish_reason_graph_integration.py @@ -0,0 +1,225 @@ +"""End-to-end graph integration test for SafetyFinishReasonMiddleware. + +Unit tests prove ``_apply`` does the right thing on a synthetic state. +This test does one level up: builds a real ``langchain.agents.create_agent`` +graph with the SafetyFinishReasonMiddleware in place, feeds it a fake model +that returns ``finish_reason='content_filter'`` + tool_calls, and asserts: + + 1. The tool node is **not** invoked (the dangerous truncated tool call + is suppressed). + 2. The final AIMessage in graph state has ``tool_calls == []``. + 3. The observability ``safety_termination`` record is attached. + 4. The user-facing explanation is appended to the message content. + +This is the closest we can get to the issue's failure mode without a live +Moonshot key, and it proves the middleware actually gates LangChain's +tool router — not just rewrites state in isolation. +""" + +from __future__ import annotations + +from typing import Any + +from langchain.agents import create_agent +from langchain.agents.middleware import AgentMiddleware +from langchain.agents.middleware.types import ModelRequest, ModelResponse +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.tools import tool + +from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware + +_TOOL_INVOCATIONS: list[dict[str, Any]] = [] + + +@tool +def write_file(path: str, content: str) -> str: + """Pretend to write *content* to *path*. Records the call for assertion.""" + _TOOL_INVOCATIONS.append({"path": path, "content": content}) + return f"wrote {len(content)} bytes to {path}" + + +class _ContentFilteredModel(BaseChatModel): + """Fake chat model that mimics OpenAI/Moonshot's content_filter response. + + First call returns finish_reason='content_filter' + a tool_call whose + arguments are visibly truncated. Second call (if reached) returns a + normal text completion so the agent can terminate cleanly. + """ + + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-content-filtered" + + def bind_tools(self, tools, **kwargs): + # create_agent binds tools onto the model; we don't actually need + # to bind anything since responses are hard-coded, but the method + # must not raise. + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self.call_count += 1 + if self.call_count == 1: + message = AIMessage( + content="Here is the report:\n# Weekly Politics\n- Meeting time: 2026-05-12—", + tool_calls=[ + { + "id": "call_truncated_1", + "name": "write_file", + "args": { + "path": "/mnt/user-data/outputs/report.md", + "content": "# Weekly Politics\n- Meeting time: 2026-05-12—", + }, + } + ], + response_metadata={"finish_reason": "content_filter", "model_name": "fake-kimi"}, + ) + else: + message = AIMessage(content="ack", response_metadata={"finish_reason": "stop"}) + return ChatResult(generations=[ChatGeneration(message=message)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + +class _InspectMiddleware(AgentMiddleware): + """Captures the messages list at every model entry so we can assert + no synthetic tool result was injected back into the conversation.""" + + def __init__(self) -> None: + super().__init__() + self.observed: list[list[Any]] = [] + + def wrap_model_call(self, request: ModelRequest, handler) -> ModelResponse: + self.observed.append(list(request.messages)) + return handler(request) + + +def test_content_filter_with_tool_calls_does_not_invoke_tool_node(): + _TOOL_INVOCATIONS.clear() + inspector = _InspectMiddleware() + + agent = create_agent( + model=_ContentFilteredModel(), + tools=[write_file], + # Inspector first so its after_model is registered; Safety last in + # the list so it executes first under LIFO (matches production wiring). + middleware=[inspector, SafetyFinishReasonMiddleware()], + ) + + result = agent.invoke({"messages": [HumanMessage(content="write me a report")]}) + + # Critical assertion: the dangerous truncated tool call must NOT have + # been executed. This is the entire point of the middleware. + assert _TOOL_INVOCATIONS == [], f"write_file was invoked despite content_filter: {_TOOL_INVOCATIONS}" + + # Final AIMessage has no tool calls left. + final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage)) + assert final_ai.tool_calls == [] + + # Observability stamp is present. + record = final_ai.additional_kwargs.get("safety_termination") + assert record is not None + assert record["detector"] == "openai_compatible_content_filter" + assert record["reason_field"] == "finish_reason" + assert record["reason_value"] == "content_filter" + assert record["suppressed_tool_call_count"] == 1 + assert record["suppressed_tool_call_names"] == ["write_file"] + + # User-facing explanation is appended. + assert "safety-related signal" in final_ai.content + # Original partial text preserved (we don't throw away what the user + # already saw in the stream — see middleware docstring). + assert "Weekly Politics" in final_ai.content + + # finish_reason on response_metadata is preserved (so SSE / converters + # downstream still see the real provider reason). + assert final_ai.response_metadata.get("finish_reason") == "content_filter" + + +def test_content_filter_without_tool_calls_passes_through_unchanged(): + """No tool calls => issue scope says don't intervene; the partial + response should be delivered as-is so the user sees what they got.""" + _TOOL_INVOCATIONS.clear() + + class _NoToolModel(BaseChatModel): + @property + def _llm_type(self) -> str: + return "fake-no-tool" + + def bind_tools(self, tools, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + msg = AIMessage( + content="Partial answer truncated by safety filter", + response_metadata={"finish_reason": "content_filter"}, + ) + return ChatResult(generations=[ChatGeneration(message=msg)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + agent = create_agent( + model=_NoToolModel(), + tools=[write_file], + middleware=[SafetyFinishReasonMiddleware()], + ) + result = agent.invoke({"messages": [HumanMessage(content="hi")]}) + final_ai = next(m for m in reversed(result["messages"]) if isinstance(m, AIMessage)) + + # Content untouched. + assert final_ai.content == "Partial answer truncated by safety filter" + # No safety_termination stamp because we didn't intervene. + assert "safety_termination" not in final_ai.additional_kwargs + # tool node never ran (there were no tool calls in the first place). + assert _TOOL_INVOCATIONS == [] + + +def test_normal_tool_call_round_trip_is_not_affected(): + """Regression: a healthy finish_reason='tool_calls' response must still + execute the tool. The middleware must not over-fire.""" + _TOOL_INVOCATIONS.clear() + + class _HealthyToolModel(BaseChatModel): + call_count: int = 0 + + @property + def _llm_type(self) -> str: + return "fake-healthy" + + def bind_tools(self, tools, **kwargs): + return self + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + self.call_count += 1 + if self.call_count == 1: + msg = AIMessage( + content="", + tool_calls=[ + { + "id": "call_ok", + "name": "write_file", + "args": {"path": "/tmp/ok", "content": "complete content"}, + } + ], + response_metadata={"finish_reason": "tool_calls"}, + ) + else: + msg = AIMessage(content="done", response_metadata={"finish_reason": "stop"}) + return ChatResult(generations=[ChatGeneration(message=msg)]) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + return self._generate(messages, stop=stop, run_manager=run_manager, **kwargs) + + agent = create_agent( + model=_HealthyToolModel(), + tools=[write_file], + middleware=[SafetyFinishReasonMiddleware()], + ) + agent.invoke({"messages": [HumanMessage(content="write")]}) + + assert _TOOL_INVOCATIONS == [{"path": "/tmp/ok", "content": "complete content"}] diff --git a/backend/tests/test_safety_finish_reason_middleware.py b/backend/tests/test_safety_finish_reason_middleware.py new file mode 100644 index 000000000..14c6226dd --- /dev/null +++ b/backend/tests/test_safety_finish_reason_middleware.py @@ -0,0 +1,651 @@ +"""Unit tests for SafetyFinishReasonMiddleware.""" + +from unittest.mock import MagicMock + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware +from deerflow.agents.middlewares.safety_termination_detectors import ( + SafetyTermination, +) +from deerflow.config.safety_finish_reason_config import ( + SafetyDetectorConfig, + SafetyFinishReasonConfig, +) + + +def _runtime(thread_id="t-1"): + runtime = MagicMock() + runtime.context = {"thread_id": thread_id} + return runtime + + +def _ai( + *, + content="", + tool_calls=None, + response_metadata=None, + additional_kwargs=None, +): + return AIMessage( + content=content, + tool_calls=tool_calls or [], + response_metadata=response_metadata or {}, + additional_kwargs=additional_kwargs or {}, + ) + + +def _write_call(idx=1, content_text="半截"): + return { + "id": f"call_write_{idx}", + "name": "write_file", + "args": {"path": "/mnt/user-data/outputs/x.md", "content": content_text}, + } + + +class AlwaysHitDetector: + """Test fixture: always reports the given termination.""" + + name = "always_hit" + + def __init__(self, *, reason_field="finish_reason", reason_value="content_filter", extras=None): + self.reason_field = reason_field + self.reason_value = reason_value + self.extras = extras or {} + + def detect(self, message): + return SafetyTermination( + detector=self.name, + reason_field=self.reason_field, + reason_value=self.reason_value, + extras=self.extras, + ) + + +class NeverHitDetector: + name = "never_hit" + + def detect(self, message): + return None + + +class RaisingDetector: + name = "raising" + + def detect(self, message): + raise RuntimeError("boom") + + +# --------------------------------------------------------------------------- +# Core trigger behaviour +# --------------------------------------------------------------------------- + + +class TestTriggerCriteria: + def test_content_filter_with_tool_calls_triggers(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="partial", + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + result = mw._apply(state, _runtime()) + assert result is not None + patched = result["messages"][0] + assert patched.tool_calls == [] + + def test_content_filter_without_tool_calls_passes_through(self): + """issue scope: when there are no tool calls the partial text is a + legitimate final response and should not be rewritten.""" + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="partial response", + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_normal_tool_calls_pass_through(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "tool_calls"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_normal_stop_with_tool_calls_pass_through(self): + # Some providers report finish_reason='stop' for tool-call messages. + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "stop"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_empty_message_list_passes_through(self): + mw = SafetyFinishReasonMiddleware() + assert mw._apply({"messages": []}, _runtime()) is None + + def test_non_ai_last_message_passes_through(self): + mw = SafetyFinishReasonMiddleware() + state = {"messages": [HumanMessage(content="hi"), SystemMessage(content="sys")]} + assert mw._apply(state, _runtime()) is None + + def test_anthropic_refusal_with_tool_calls_triggers(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"stop_reason": "refusal"}, + ) + ] + } + result = mw._apply(state, _runtime()) + assert result is not None + assert result["messages"][0].tool_calls == [] + + def test_gemini_safety_with_tool_calls_triggers(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "SAFETY"}, + ) + ] + } + result = mw._apply(state, _runtime()) + assert result is not None + assert result["messages"][0].tool_calls == [] + + +# --------------------------------------------------------------------------- +# Message rewriting +# --------------------------------------------------------------------------- + + +class TestMessageRewrite: + def test_clears_structured_tool_calls(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call(1), _write_call(2)], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + result = mw._apply(state, _runtime()) + patched = result["messages"][0] + assert patched.tool_calls == [] + + def test_clears_raw_additional_kwargs_tool_calls(self): + """Critical defence-in-depth: DanglingToolCallMiddleware will recover + tool calls from additional_kwargs.tool_calls if we forget them, which + would re-emit a synthetic ToolMessage downstream and confuse the + model. We must wipe both.""" + mw = SafetyFinishReasonMiddleware() + raw_tool_calls = [ + { + "id": "call_write_1", + "type": "function", + "function": {"name": "write_file", "arguments": '{"path": "/x"}'}, + } + ] + state = { + "messages": [ + _ai( + tool_calls=[_write_call(1)], + response_metadata={"finish_reason": "content_filter"}, + additional_kwargs={ + "tool_calls": raw_tool_calls, + "function_call": {"name": "write_file", "arguments": "{}"}, + }, + ) + ] + } + result = mw._apply(state, _runtime()) + patched = result["messages"][0] + assert "tool_calls" not in patched.additional_kwargs + assert "function_call" not in patched.additional_kwargs + + def test_preserves_other_additional_kwargs(self): + # vLLM puts reasoning under additional_kwargs.reasoning; Anthropic + # may carry other provider-specific keys. They must not be wiped. + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + additional_kwargs={ + "reasoning": "thinking text", + "custom_provider_field": {"x": 1}, + }, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.additional_kwargs["reasoning"] == "thinking text" + assert patched.additional_kwargs["custom_provider_field"] == {"x": 1} + + def test_writes_observability_field(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call(1), _write_call(2)], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + record = patched.additional_kwargs["safety_termination"] + assert record["detector"] == "openai_compatible_content_filter" + assert record["reason_field"] == "finish_reason" + assert record["reason_value"] == "content_filter" + assert record["suppressed_tool_call_count"] == 2 + assert record["suppressed_tool_call_names"] == ["write_file", "write_file"] + + def test_preserves_response_metadata_finish_reason(self): + """Downstream SSE converters read response_metadata.finish_reason — + we want them to see the *real* provider reason, not 'stop'.""" + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter", "model_name": "kimi-k2"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.response_metadata["finish_reason"] == "content_filter" + assert patched.response_metadata["model_name"] == "kimi-k2" + + def test_appends_user_facing_explanation_to_str_content(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="some partial text", + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert isinstance(patched.content, str) + assert patched.content.startswith("some partial text") + assert "safety-related signal" in patched.content + + def test_handles_empty_content(self): + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + content="", + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert isinstance(patched.content, str) + assert "safety-related signal" in patched.content + + def test_handles_list_content_thinking_blocks(self): + """Anthropic thinking / vLLM reasoning models emit content blocks. + Naively concatenating a string would raise TypeError.""" + mw = SafetyFinishReasonMiddleware() + thinking_blocks = [ + {"type": "thinking", "text": "let me consider..."}, + {"type": "text", "text": "partial answer"}, + ] + state = { + "messages": [ + _ai( + content=thinking_blocks, + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + patched = mw._apply(state, _runtime())["messages"][0] + assert isinstance(patched.content, list) + assert patched.content[:2] == thinking_blocks + assert patched.content[-1]["type"] == "text" + assert "safety-related signal" in patched.content[-1]["text"] + + def test_idempotent_on_already_cleared_message(self): + # Re-running the middleware on a message we already cleared must not + # re-trigger (tool_calls is now empty → fast passthrough). + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + first = mw._apply(state, _runtime()) + state2 = {"messages": [first["messages"][0]]} + second = mw._apply(state2, _runtime()) + assert second is None + + def test_preserves_message_id_for_add_messages_replacement(self): + """LangGraph's add_messages reducer treats same-id messages as + replacements. model_copy keeps id by default.""" + mw = SafetyFinishReasonMiddleware() + original = _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + # AIMessage auto-generates id; capture it + original_id = original.id + state = {"messages": [original]} + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.id == original_id + + +# --------------------------------------------------------------------------- +# Detector wiring +# --------------------------------------------------------------------------- + + +class TestDetectorWiring: + def test_iterates_detectors_in_order(self): + first = AlwaysHitDetector(reason_value="first") + second = AlwaysHitDetector(reason_value="second") + mw = SafetyFinishReasonMiddleware(detectors=[first, second]) + state = {"messages": [_ai(tool_calls=[_write_call()])]} + patched = mw._apply(state, _runtime())["messages"][0] + assert patched.additional_kwargs["safety_termination"]["reason_value"] == "first" + + def test_returns_none_when_no_detector_matches(self): + mw = SafetyFinishReasonMiddleware(detectors=[NeverHitDetector(), NeverHitDetector()]) + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + assert mw._apply(state, _runtime()) is None + + def test_buggy_detector_does_not_break_run(self): + mw = SafetyFinishReasonMiddleware(detectors=[RaisingDetector(), AlwaysHitDetector()]) + state = {"messages": [_ai(tool_calls=[_write_call()])]} + result = mw._apply(state, _runtime()) + assert result is not None + assert result["messages"][0].additional_kwargs["safety_termination"]["detector"] == "always_hit" + + def test_constructor_copies_detectors(self): + """Caller mutation after construction must not leak into us.""" + detectors = [AlwaysHitDetector()] + mw = SafetyFinishReasonMiddleware(detectors=detectors) + detectors.clear() + state = {"messages": [_ai(tool_calls=[_write_call()])]} + assert mw._apply(state, _runtime()) is not None + + +# --------------------------------------------------------------------------- +# from_config +# --------------------------------------------------------------------------- + + +class TestFromConfig: + def test_default_config_uses_builtin_detectors(self): + mw = SafetyFinishReasonMiddleware.from_config(SafetyFinishReasonConfig()) + assert len(mw._detectors) == 3 + names = {d.name for d in mw._detectors} + assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"} + + def test_custom_detectors_loaded_via_reflection(self): + cfg = SafetyFinishReasonConfig( + detectors=[ + SafetyDetectorConfig( + use="deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector", + config={"finish_reasons": ["custom_filter"]}, + ), + ] + ) + mw = SafetyFinishReasonMiddleware.from_config(cfg) + assert len(mw._detectors) == 1 + # Confirm the kwargs propagated. + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "custom_filter"}, + ) + ] + } + assert mw._apply(state, _runtime()) is not None + # Default token no longer matches. + state2 = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + assert mw._apply(state2, _runtime()) is None + + def test_empty_detector_list_rejected(self): + cfg = SafetyFinishReasonConfig(detectors=[]) + with pytest.raises(ValueError, match="enabled=false"): + SafetyFinishReasonMiddleware.from_config(cfg) + + def test_non_detector_class_rejected(self): + cfg = SafetyFinishReasonConfig( + detectors=[SafetyDetectorConfig(use="builtins:dict")], + ) + with pytest.raises(TypeError): + SafetyFinishReasonMiddleware.from_config(cfg) + + +# --------------------------------------------------------------------------- +# Stream event +# --------------------------------------------------------------------------- + + +class TestAuditEvent: + """Verify SafetyFinishReasonMiddleware records a `middleware:safety_termination` + audit event via RunJournal.record_middleware when the run-scoped journal is + exposed under runtime.context["__run_journal"]. + + Background: review on PR #3035 — SSE custom event handles live consumers, + but post-run audit needs a row in run_events that can be queried with one + SQL statement (no JOIN against message body). + """ + + def _runtime_with_journal(self, journal): + runtime = MagicMock() + runtime.context = {"thread_id": "t-audit", "__run_journal": journal} + return runtime + + def test_records_audit_event_when_journal_present(self): + journal = MagicMock() + mw = SafetyFinishReasonMiddleware() + tc = _write_call(1) + state = { + "messages": [ + _ai( + content="partial", + tool_calls=[tc], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + result = mw._apply(state, self._runtime_with_journal(journal)) + assert result is not None + + journal.record_middleware.assert_called_once() + call = journal.record_middleware.call_args + # tag is positional or kwarg depending on call style; we use kwargs. + assert call.kwargs["tag"] == "safety_termination" + assert call.kwargs["name"] == "SafetyFinishReasonMiddleware" + assert call.kwargs["hook"] == "after_model" + assert call.kwargs["action"] == "suppress_tool_calls" + + changes = call.kwargs["changes"] + assert changes["detector"] == "openai_compatible_content_filter" + assert changes["reason_field"] == "finish_reason" + assert changes["reason_value"] == "content_filter" + assert changes["suppressed_tool_call_count"] == 1 + assert changes["suppressed_tool_call_names"] == ["write_file"] + assert changes["suppressed_tool_call_ids"] == ["call_write_1"] + assert "message_id" in changes + assert isinstance(changes["extras"], dict) + + def test_audit_event_never_carries_tool_arguments(self): + """PR #3035 review IMPORTANT: tool args are the filtered content itself + and must NOT be persisted to run_events under any circumstance.""" + journal = MagicMock() + mw = SafetyFinishReasonMiddleware() + sensitive_tc = { + "id": "call_x", + "name": "write_file", + "args": {"path": "/x", "content": "FILTERED_CONTENT_DO_NOT_PERSIST"}, + } + state = { + "messages": [ + _ai( + tool_calls=[sensitive_tc], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + mw._apply(state, self._runtime_with_journal(journal)) + flat = repr(journal.record_middleware.call_args) + assert "FILTERED_CONTENT_DO_NOT_PERSIST" not in flat, "tool arguments must not leak into audit event" + assert "args" not in journal.record_middleware.call_args.kwargs["changes"] + + def test_no_journal_in_runtime_context_is_silently_skipped(self): + """Subagent runtime / unit tests / no-event-store paths have no journal. + Middleware must still intervene and clear tool_calls — only the audit + event is skipped.""" + mw = SafetyFinishReasonMiddleware() + runtime = MagicMock() + runtime.context = {"thread_id": "t-noj"} # no __run_journal + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + # Should not raise; should still clear tool_calls. + result = mw._apply(state, runtime) + assert result is not None + assert result["messages"][0].tool_calls == [] + + def test_journal_record_exception_does_not_break_run(self): + """Buggy journal must never propagate an exception into the agent loop.""" + journal = MagicMock() + journal.record_middleware.side_effect = RuntimeError("db down") + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + # Must not raise. + result = mw._apply(state, self._runtime_with_journal(journal)) + assert result is not None + assert result["messages"][0].tool_calls == [] + + def test_no_record_when_passthrough(self): + """When the middleware does NOT intervene, no audit event is written.""" + journal = MagicMock() + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "tool_calls"}, # healthy + ) + ] + } + assert mw._apply(state, self._runtime_with_journal(journal)) is None + journal.record_middleware.assert_not_called() + + +class TestStreamEvent: + def test_emits_event_when_writer_available(self, monkeypatch): + captured: list = [] + + def fake_writer(payload): + captured.append(payload) + + # Patch get_stream_writer at the symbol-resolution site. + import langgraph.config + + monkeypatch.setattr(langgraph.config, "get_stream_writer", lambda: fake_writer) + + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + mw._apply(state, _runtime("t-stream")) + + assert len(captured) == 1 + payload = captured[0] + assert payload["type"] == "safety_termination" + assert payload["detector"] == "openai_compatible_content_filter" + assert payload["reason_field"] == "finish_reason" + assert payload["reason_value"] == "content_filter" + assert payload["suppressed_tool_call_count"] == 1 + assert payload["suppressed_tool_call_names"] == ["write_file"] + assert payload["thread_id"] == "t-stream" + + def test_writer_unavailable_does_not_break(self, monkeypatch): + import langgraph.config + + def boom(): + raise LookupError("not in a stream context") + + monkeypatch.setattr(langgraph.config, "get_stream_writer", boom) + + mw = SafetyFinishReasonMiddleware() + state = { + "messages": [ + _ai( + tool_calls=[_write_call()], + response_metadata={"finish_reason": "content_filter"}, + ) + ] + } + # Should not raise. + result = mw._apply(state, _runtime()) + assert result is not None diff --git a/backend/tests/test_safety_termination_detectors.py b/backend/tests/test_safety_termination_detectors.py new file mode 100644 index 000000000..0679aed0e --- /dev/null +++ b/backend/tests/test_safety_termination_detectors.py @@ -0,0 +1,176 @@ +"""Unit tests for SafetyTerminationDetector built-ins.""" + +from langchain_core.messages import AIMessage + +from deerflow.agents.middlewares.safety_termination_detectors import ( + AnthropicRefusalDetector, + GeminiSafetyDetector, + OpenAICompatibleContentFilterDetector, + SafetyTermination, + SafetyTerminationDetector, + default_detectors, +) + + +def _ai(*, content="", tool_calls=None, response_metadata=None, additional_kwargs=None) -> AIMessage: + return AIMessage( + content=content, + tool_calls=tool_calls or [], + response_metadata=response_metadata or {}, + additional_kwargs=additional_kwargs or {}, + ) + + +class TestOpenAICompatibleContentFilterDetector: + def test_default_matches_content_filter(self): + d = OpenAICompatibleContentFilterDetector() + hit = d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) + assert hit is not None + assert hit.detector == "openai_compatible_content_filter" + assert hit.reason_field == "finish_reason" + assert hit.reason_value == "content_filter" + + def test_case_insensitive_match(self): + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai(response_metadata={"finish_reason": "CONTENT_FILTER"})) is not None + + def test_other_finish_reasons_pass_through(self): + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai(response_metadata={"finish_reason": "stop"})) is None + assert d.detect(_ai(response_metadata={"finish_reason": "tool_calls"})) is None + assert d.detect(_ai(response_metadata={"finish_reason": "length"})) is None + + def test_missing_metadata_passes_through(self): + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai()) is None + + def test_non_string_finish_reason_passes_through(self): + # Some adapters may stash an enum or dict — must not raise. + d = OpenAICompatibleContentFilterDetector() + assert d.detect(_ai(response_metadata={"finish_reason": 42})) is None + assert d.detect(_ai(response_metadata={"finish_reason": {"value": "content_filter"}})) is None + + def test_falls_back_to_additional_kwargs(self): + # Legacy adapters surface finish_reason via additional_kwargs. + d = OpenAICompatibleContentFilterDetector() + hit = d.detect(_ai(additional_kwargs={"finish_reason": "content_filter"})) + assert hit is not None + + def test_configurable_extra_values(self): + # Chinese providers sometimes use bespoke tokens. + d = OpenAICompatibleContentFilterDetector(finish_reasons=["content_filter", "sensitive", "violation"]) + assert d.detect(_ai(response_metadata={"finish_reason": "sensitive"})) is not None + assert d.detect(_ai(response_metadata={"finish_reason": "violation"})) is not None + # Original token still matches. + assert d.detect(_ai(response_metadata={"finish_reason": "content_filter"})) is not None + + def test_carries_azure_content_filter_results(self): + d = OpenAICompatibleContentFilterDetector() + filter_results = {"hate": {"filtered": True, "severity": "high"}} + hit = d.detect( + _ai( + response_metadata={ + "finish_reason": "content_filter", + "content_filter_results": filter_results, + }, + ) + ) + assert hit is not None + assert hit.extras["content_filter_results"] == filter_results + + +class TestAnthropicRefusalDetector: + def test_default_matches_refusal(self): + hit = AnthropicRefusalDetector().detect(_ai(response_metadata={"stop_reason": "refusal"})) + assert hit is not None + assert hit.reason_field == "stop_reason" + assert hit.reason_value == "refusal" + + def test_other_stop_reasons_pass_through(self): + d = AnthropicRefusalDetector() + assert d.detect(_ai(response_metadata={"stop_reason": "end_turn"})) is None + assert d.detect(_ai(response_metadata={"stop_reason": "tool_use"})) is None + assert d.detect(_ai(response_metadata={"stop_reason": "max_tokens"})) is None + + def test_anthropic_does_not_steal_finish_reason(self): + # An OpenAI message must not accidentally trip the Anthropic detector. + assert AnthropicRefusalDetector().detect(_ai(response_metadata={"finish_reason": "content_filter"})) is None + + +class TestGeminiSafetyDetector: + def test_default_set_covers_documented_reasons(self): + d = GeminiSafetyDetector() + for reason in ( + # text safety + "SAFETY", + "BLOCKLIST", + "PROHIBITED_CONTENT", + "SPII", + "RECITATION", + # image safety + "IMAGE_SAFETY", + "IMAGE_PROHIBITED_CONTENT", + "IMAGE_RECITATION", + ): + assert d.detect(_ai(response_metadata={"finish_reason": reason})) is not None, reason + + def test_normal_termination_passes_through(self): + d = GeminiSafetyDetector() + assert d.detect(_ai(response_metadata={"finish_reason": "STOP"})) is None + # MAX_TOKENS / LANGUAGE / NO_IMAGE / OTHER / IMAGE_OTHER / + # MALFORMED_FUNCTION_CALL / UNEXPECTED_TOOL_CALL are intentionally + # excluded from the default set — they are either normal termination, + # capability mismatches, too broad (OTHER), or tool-call protocol + # errors. See GeminiSafetyDetector docstring. + for reason in ( + "MAX_TOKENS", + "LANGUAGE", + "NO_IMAGE", + "OTHER", + "IMAGE_OTHER", + "MALFORMED_FUNCTION_CALL", + "UNEXPECTED_TOOL_CALL", + "FINISH_REASON_UNSPECIFIED", + ): + assert d.detect(_ai(response_metadata={"finish_reason": reason})) is None, reason + + def test_carries_safety_ratings(self): + ratings = [{"category": "HARM_CATEGORY_HARASSMENT", "probability": "HIGH"}] + hit = GeminiSafetyDetector().detect( + _ai( + response_metadata={ + "finish_reason": "SAFETY", + "safety_ratings": ratings, + }, + ) + ) + assert hit is not None + assert hit.extras["safety_ratings"] == ratings + + +class TestDefaultDetectorSet: + def test_default_set_returns_three_detectors(self): + dets = default_detectors() + names = {d.name for d in dets} + assert names == {"openai_compatible_content_filter", "anthropic_refusal", "gemini_safety"} + + def test_default_set_returns_fresh_list(self): + # Caller mutation must not affect later calls. + first = default_detectors() + first.clear() + second = default_detectors() + assert len(second) == 3 + + +class TestProtocolConformance: + def test_builtins_satisfy_protocol(self): + for d in default_detectors(): + assert isinstance(d, SafetyTerminationDetector) + + def test_safety_termination_is_frozen(self): + t = SafetyTermination(detector="x", reason_field="finish_reason", reason_value="content_filter") + try: + t.detector = "y" # type: ignore[misc] + except Exception: + return + raise AssertionError("SafetyTermination should be frozen") diff --git a/backend/tests/test_tool_error_handling_middleware.py b/backend/tests/test_tool_error_handling_middleware.py index 2c28dac35..28c59a9ad 100644 --- a/backend/tests/test_tool_error_handling_middleware.py +++ b/backend/tests/test_tool_error_handling_middleware.py @@ -134,8 +134,14 @@ def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware middlewares = build_subagent_runtime_middlewares(app_config=app_config, lazy_init=False) assert captured["app_config"] is app_config - assert len(middlewares) == 6 - assert isinstance(middlewares[-1], ToolErrorHandlingMiddleware) + # 6 baseline (ThreadData, Sandbox, DanglingToolCall, LLMErrorHandling, + # SandboxAudit, ToolErrorHandling) + 1 SafetyFinishReasonMiddleware + # (enabled by default — see SafetyFinishReasonConfig). + from deerflow.agents.middlewares.safety_finish_reason_middleware import SafetyFinishReasonMiddleware + + assert len(middlewares) == 7 + assert any(isinstance(m, ToolErrorHandlingMiddleware) for m in middlewares) + assert isinstance(middlewares[-1], SafetyFinishReasonMiddleware) def test_wrap_tool_call_passthrough_on_success(): diff --git a/config.example.yaml b/config.example.yaml index 9ea4e4c08..8e289fac9 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -15,7 +15,7 @@ # ============================================================================ # Bump this number when the config schema changes. # Run `make config-upgrade` to merge new fields into your local config.yaml. -config_version: 9 +config_version: 10 # ============================================================================ # Logging @@ -535,6 +535,41 @@ loop_detection: # warn: 150 # hard_limit: 300 +# ============================================================================ +# Provider Safety Termination Configuration +# ============================================================================ +# Intercept AIMessages where the provider stopped generation for safety reasons +# (e.g. OpenAI finish_reason='content_filter', Anthropic stop_reason='refusal', +# Gemini finish_reason='SAFETY') while still returning tool_calls. The +# tool_calls in such responses are typically truncated/unreliable and must +# not be executed. See issue #3028 for the full failure mode. +# +# Detectors are loaded by class path via reflection (same pattern as +# guardrails / models / tools). The built-in set covers OpenAI-compatible +# content_filter, Anthropic refusal, and Gemini SAFETY/BLOCKLIST/ +# PROHIBITED_CONTENT/SPII/RECITATION. + +safety_finish_reason: + enabled: true + # Leave `detectors` unset to use the built-in detector set. Set to a + # non-empty list to fully override (use `enabled: false` to disable instead + # of providing an empty list). + # + # Example — extend the OpenAI-compatible detector for a Chinese provider + # whose gateway uses a non-standard finish_reason token: + # detectors: + # - use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector + # config: + # finish_reasons: ["content_filter", "sensitive", "risk_control"] + # - use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector + # - use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector + # + # Example — add a custom detector for an in-house provider: + # detectors: + # - use: my_company.deerflow_ext:WenxinSafetyDetector + # config: + # error_codes: [336003, 17, 18] + # ============================================================================ # Sandbox Configuration # ============================================================================ From 914d6a4f1c4b0b3330c123168cf8dabfca86b04b Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 22 May 2026 15:33:15 +0200 Subject: [PATCH 68/86] docs: add provider safety termination post (#3167) --- ...ider-safety-termination-in-tool-agents.mdx | 124 +++++++++++++++++ ...ider-safety-termination-in-tool-agents.mdx | 125 ++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 frontend/src/content/en/posts/provider-safety-termination-in-tool-agents.mdx create mode 100644 frontend/src/content/zh/posts/provider-safety-termination-in-tool-agents.mdx diff --git a/frontend/src/content/en/posts/provider-safety-termination-in-tool-agents.mdx b/frontend/src/content/en/posts/provider-safety-termination-in-tool-agents.mdx new file mode 100644 index 000000000..f72c57770 --- /dev/null +++ b/frontend/src/content/en/posts/provider-safety-termination-in-tool-agents.mdx @@ -0,0 +1,124 @@ +--- +title: Tool-Using Agents Must Handle Provider Safety Termination Signals Correctly +description: Why tool calls left in a safety-terminated model response must not be executed, and how to configure provider detectors in DeerFlow. +date: 2026-05-22 +tags: + - Safety + - Agents + - Model Providers +--- + +## Tool-Using Agents Must Handle Provider Safety Termination Signals Correctly + +When a large model provider decides that an input or output has triggered a safety policy, the important outcome is not merely that the model says less. The application needs to know that the current generation turn has been terminated. In a normal chat interface, this may appear as a refusal, filtered text, or an error response. For an Agent that can call tools, the risk is higher: if the provider has already stopped generation while the response still contains `tool_calls`, those tool arguments may only be partially generated. + +These partial tool calls must not be executed as normal intent. A truncated `write_file` call may write an incomplete report. A truncated `bash` call may enter the sandbox with incomplete arguments. After seeing the failed result, the Agent may retry and trigger the same safety rule repeatedly. + +[PR #3035](https://github.com/bytedance/deer-flow/pull/3035) addresses this boundary: when a provider stops generation with a safety signal while the response still contains tool calls, DeerFlow should suppress those tool calls first and record the turn as a safety termination event. + +## Why Safety Termination Needs Dedicated Handling + +A safety termination is not a normal tool-call finish reason. + +In a healthy tool turn, the provider explicitly tells the application that it should call tools. A safety termination says something different: the output has been blocked by provider policy, or streaming generation has been cut off early. Even if tool-call fragments remain in the response object, the application cannot assume that their JSON arguments, file contents, or command text are complete. + +In a real Agent run, this creates two kinds of risk: + +| Risk | Impact | +| --- | --- | +| Runtime risk | Executing truncated tool arguments can create corrupted files, malformed commands, repeated retries, or tool loops | +| Provider risk | Repeatedly sending similar violating inputs or outputs to a provider increases safety review and abuse-control pressure | + +The second risk matters. Providers enforce their policies differently, but their official materials already make clear that safety policy can affect more than a single completion. It can also affect end users, API access, or account status. + +## What Providers Expose and How They Respond + +Providers do not use one common field name, and they do not share one enforcement process. Deployments need to distinguish at least two layers: + +1. Which signal in this response says that generation was stopped by a safety policy. +2. Which follow-up actions the provider has publicly described when safety problems keep recurring. + +| Provider | Runtime signal | Publicly documented response or recommendation | +| --- | --- | --- | +| GLM | Synchronous calls may return a safety audit error; streaming output may end with `finish_reason="sensitive"` | Pass `user_id` to distinguish end users; the platform may block violating end-user requests so enterprise accounts are not affected by end-user abuse | +| OpenAI | Chat Completions may return `finish_reason="content_filter"` | Use Moderation and `safety_identifier`; repeated usage policy violations may lead to warnings, restrictions, or account deactivation | +| Anthropic | Streaming refusals may be exposed through `stop_reason="refusal"` | Reset, rewrite, or narrow context after a refusal; the AUP describes request limiting, output modification, suspension, or termination | +| Gemini | A safety-filtered candidate may return `finishReason=SAFETY`, and blocked content is not returned | Abuse monitoring covers prompts and outputs; follow-up actions can escalate from contacting the developer to temporary restrictions, suspension, or account closure | +| DeepSeek | Chat completion `finish_reason` includes `content_filter` | The `user` field can help content safety review; potential usage guideline violations may trigger a temporary suspension protocol | + +GLM is the most direct example. Its safety audit documentation describes the streaming safety finish signal, the recommendation to identify end users, and the possibility of blocking requests from violating end users. [GLM safety audit documentation](https://docs.bigmodel.cn/cn/guide/platform/securityaudit) + +OpenAI defines `content_filter` as a Chat Completions finish reason. Its safety best practices recommend using `safety_identifier` for end users so policy violations can be attributed more precisely than a shared API key alone. OpenAI help documentation also says repeated usage policy violations may lead to account deactivation. [Safety best practices](https://developers.openai.com/api/docs/guides/safety-best-practices/) [Why Was My OpenAI Account Deactivated?](https://help.openai.com/en/articles/10562188) + +Anthropic distinguishes ordinary stops from safety refusals in its Claude streaming refusal guidance: when the streaming classifier intervenes, the response can carry `stop_reason="refusal"`. It also recommends that applications do not keep feeding refused content back into later context, and instead reset the conversation, rewrite the prompt, or narrow the task. The Anthropic AUP says it may limit requests, block or modify outputs, and suspend or terminate access when necessary. [Handle streaming refusals](https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals) [Acceptable Use Policy](https://www.anthropic.com/legal/aup) + +Gemini safety documentation emphasizes another shape of intervention. A prompt may be blocked before generation, and a candidate may be filtered after generation. When a response candidate is stopped by safety policy, the response can expose `finishReason=SAFETY` without returning the blocked content itself. Gemini API terms also say abuse monitoring covers prompts and outputs and list progressively stronger follow-up actions. [Gemini safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) [Gemini API Additional Terms of Service](https://ai.google.dev/gemini-api/terms) + +DeepSeek lists `content_filter` as a chat completion finish reason and describes the request `user` field as helpful for content safety review. Its FAQ also says potential usage guideline violations may trigger a temporary suspension process. [Create Chat Completion](https://api-docs.deepseek.com/api/create-chat-completion) + +Some providers intervene earlier or at a layer outside the model message. For example, Azure OpenAI tells applications to inspect `finish_reason` because `content_filter` may leave a completion incomplete. Amazon Bedrock Guardrails can return `stopReason="guardrail_intervened"` in a response. In Alibaba Cloud Model Studio guardrail examples, output-side blocking may also appear directly as a `DataInspectionFailed` error. Together, these examples show that a safety intervention may be a stop signal in a model message or an API-level error. Applications need more than one handling path. [Azure OpenAI content filtering](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter) [Amazon Bedrock Guardrails](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-converse-api.html) + +## What DeerFlow Does at This Boundary + +`SafetyFinishReasonMiddleware` has a narrow responsibility. It does not replace provider content review, and it does not rewrite every refusal into the same error. It only intervenes when both conditions below are true: + +1. The provider response carries a configured safety termination signal. +2. The current `AIMessage` still contains non-empty `tool_calls`. + +When it intervenes, it: + +1. Clears structured tool calls and residual tool-call fields in raw provider metadata. +2. Prevents those tool arguments from reaching the tool node for execution. +3. Preserves already generated partial text and appends a user-facing explanation. +4. Records the detector, reason field, reason value, and suppressed tool names and counts. +5. Avoids writing tool arguments that may themselves contain filtered content into audit events again. + +This makes the safety termination signal take priority over the fact that tool calls are present in the response. For the Agent runtime, that is the more conservative and more correct control flow. + +## Default Configuration + +The default configuration only needs `safety_finish_reason` enabled: + +```yaml +safety_finish_reason: + enabled: true +``` + +When `detectors` is not configured explicitly, DeerFlow uses the built-in detector set: + +| Detector | Default match | +| --- | --- | +| `OpenAICompatibleContentFilterDetector` | `finish_reason="content_filter"` | +| `AnthropicRefusalDetector` | `stop_reason="refusal"` | +| `GeminiSafetyDetector` | Gemini safety-related `finish_reason` values such as `SAFETY`, `BLOCKLIST`, `PROHIBITED_CONTENT`, `SPII`, and `RECITATION` | + +This default set covers common DeerFlow paths for OpenAI-compatible providers, Anthropic, and Gemini. It does not treat a normal `finish_reason="tool_calls"` as a safety termination, and it does not fold length truncation such as `length` or `max_tokens` into the safety category. + +## Example: Extend the Streaming Safety Finish Signal for GLM + +GLM streaming responses use `sensitive` as the safety finish value. If the current adapter preserves that value in `AIMessage.response_metadata.finish_reason` or `additional_kwargs.finish_reason`, it can be handled through the configurable finish reason set on the OpenAI-compatible detector: + +```yaml +safety_finish_reason: + enabled: true + detectors: + - use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector + config: + finish_reasons: ["content_filter", "sensitive"] + + - use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector + + - use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector +``` + +Two configuration details matter here. + +First, `detectors` replaces the default list. It does not append one item to it. The example therefore keeps the Anthropic and Gemini detectors while adding GLM's `sensitive` value. + +Second, this middleware handles safety finish signals that have already reached a model message. If the provider returns a safety audit error at the API layer, such as a synchronous GLM safety audit error code, the caller still needs to handle it in the LLM or API error path. + +## Boundary + +`SafetyFinishReasonMiddleware` solves a specific Agent control-flow problem. It is not a complete content safety solution. It does not replace moderation, permission isolation, user governance, or provider-side review, and it does not cover every plain-text refusal. + +This boundary is still worth protecting explicitly: when a provider has already stopped output for safety reasons, a tool-using Agent should treat that turn as interrupted output, not executable tool intent. diff --git a/frontend/src/content/zh/posts/provider-safety-termination-in-tool-agents.mdx b/frontend/src/content/zh/posts/provider-safety-termination-in-tool-agents.mdx new file mode 100644 index 000000000..4979fa397 --- /dev/null +++ b/frontend/src/content/zh/posts/provider-safety-termination-in-tool-agents.mdx @@ -0,0 +1,125 @@ +--- +title: 工具型 Agent 需要正确处理模型提供商的安全中止信号 +description: 当模型输出因安全策略被中止时,为什么不能继续执行残留的工具调用,以及如何在 DeerFlow 中配置 provider detector。 +date: 2026-05-22 +tags: + - Safety + - Agents + - Model Providers +--- + +## 工具型 Agent 需要正确处理模型提供商的安全中止信号 + +当大模型提供商认为输入或输出触发了安全策略时,最理想的结果不是“模型少说了几句话”,而是应用已经明确知道这一轮生成被中止了。对于普通聊天界面,这通常表现为拒答、过滤后的文本,或者一个错误响应。对于能调用工具的 Agent,风险会更高:如果 provider 已经中止输出,但响应里仍残留了 `tool_calls`,这些工具参数很可能只生成了一半。 + +这类半截工具调用不应被当成正常意图执行。一个被截断的 `write_file` 可能写出不完整的报告;一个被截断的 `bash` 调用可能带着残缺参数进入沙箱;Agent 看到失败结果后还可能继续重试,反复触发同一条安全规则。 + +[PR #3035](https://github.com/bytedance/deer-flow/pull/3035) 处理的就是这个边界:当 provider 用安全信号中止生成,同时响应仍带有工具调用时,DeerFlow 应先压制这些工具调用,再把这一轮作为安全中止事件记录下来。 + +## 为什么需要单独处理安全中止 + +安全中止不是普通的工具调用结束原因。 + +一次健康的工具轮次通常由 provider 明确告诉应用“现在应该调用工具”。但安全中止表达的是另一件事:输出已经被 provider 的策略拦住,或者流式生成已经被提前切断。此时即使响应对象里还能看到工具调用片段,也不能假设它的 JSON 参数、文件内容或命令文本已经完整。 + +在真实 Agent 运行中,这会同时产生两类风险: + +| 风险 | 影响 | +| --- | --- | +| 运行时风险 | 执行被截断的工具参数,产生损坏文件、异常命令、重复重试或工具循环 | +| provider 风险 | 应用反复把同类违规输入或输出送到 provider,累积安全审核和风控压力 | + +第二类风险不能被忽略。不同 provider 的处置力度不同,但官方材料已经表明,安全策略不仅影响单次 completion,也可能影响终端用户、API 访问能力或账号状态。 + +## 各家 provider 公开了什么信号和处置方式 + +provider 并没有统一的字段名,也没有统一的处罚流程。部署方至少要区分两层信息: + +1. 这一轮响应里,什么信号说明生成被安全策略中止。 +2. 如果安全问题反复出现,provider 公开说明过哪些后续动作。 + +| Provider | 运行时信号 | 公开的后续处置或建议 | +| --- | --- | --- | +| GLM | 同步调用可能返回安全审核错误;流式输出可能以 `finish_reason="sensitive"` 结束 | 建议传入 `user_id` 区分终端用户;平台可封禁违规终端用户请求,避免企业账号受终端用户滥用影响 | +| OpenAI | Chat Completions 的 `finish_reason` 可为 `content_filter` | 建议使用 Moderation 和 `safety_identifier`;重复违反使用政策可能带来警告、限制或账号停用 | +| Anthropic | 流式拒绝可通过 `stop_reason="refusal"` 暴露 | 收到拒绝后应重置、改写或缩小上下文;AUP 说明可限制请求、修改输出、暂停或终止访问 | +| Gemini | 被安全过滤的 candidate 可返回 `finishReason=SAFETY`,且被拦截内容不会返回 | abuse monitoring 会检查 prompts 和 outputs;后续动作可从联系开发者升级到临时限制、暂停或账号关闭 | +| DeepSeek | Chat completion 的 `finish_reason` 枚举包含 `content_filter` | `user` 字段可帮助内容安全审核;潜在使用规范违规可能触发临时 suspension protocol | + +GLM 的说明最直接。它的安全审核文档同时给出了流式安全结束信号、终端用户标识建议,以及对违规终端用户请求做封禁处理的说明。[GLM 安全审核文档](https://docs.bigmodel.cn/cn/guide/platform/securityaudit) + +OpenAI 把 `content_filter` 定义为 Chat Completions 的一种 finish reason,并在安全最佳实践中推荐对终端用户使用 `safety_identifier`,以便在违反策略时定位到具体用户而不是只看到一个共享的 API key。OpenAI 的帮助文档还说明,重复违反使用政策可能导致账号被停用。 [Safety best practices](https://developers.openai.com/api/docs/guides/safety-best-practices/) [Why Was My OpenAI Account Deactivated?](https://help.openai.com/en/articles/10562188) + +Anthropic 在 Claude 流式拒绝说明中明确区分了普通停止和安全拒绝:当 streaming classifier 介入时,响应可以带有 `stop_reason="refusal"`。它同时建议应用不要把被拒绝内容继续塞回下一轮上下文,而应重置对话、改写提示或缩小任务范围。Anthropic AUP 也说明,它可以限制请求、拦截或修改输出,并在必要时暂停或终止访问。[Handle streaming refusals](https://platform.claude.com/docs/en/test-and-evaluate/strengthen-guardrails/handle-streaming-refusals) [Acceptable Use Policy](https://www.anthropic.com/legal/aup) + +Gemini 的安全文档则强调另一种形态:prompt 可能在生成前被拦截,candidate 也可能在生成后被过滤;当 response candidate 被安全策略拦下时,可以看到 `finishReason=SAFETY`,但不会拿到被拦截内容本身。Gemini API 的使用政策还说明,abuse monitoring 会覆盖 prompts 和 outputs,并列出了逐步升级的处置动作。[Gemini safety settings](https://ai.google.dev/gemini-api/docs/safety-settings) [Gemini API Additional Terms of Service](https://ai.google.dev/gemini-api/terms) + +DeepSeek 的 API 文档把 `content_filter` 列为 chat completion finish reason,并把请求里的 `user` 字段说明为有助于内容安全审核。它的 FAQ 也说明,潜在违反使用规范的场景可能触发临时暂停流程。[Create Chat Completion](https://api-docs.deepseek.com/api/create-chat-completion) [DeepSeek FAQ](https://api-docs.deepseek.com/faq) + +还有一些 provider 会在更早或更外层的位置拦截请求。例如 Azure OpenAI 提醒应用检查 `finish_reason`,因为 `content_filter` 可能让 completion 不完整;Amazon Bedrock Guardrails 可在响应中返回 `stopReason="guardrail_intervened"`;阿里云百炼的安全护栏示例里,输出侧拦截也可能直接表现为 `DataInspectionFailed` 错误。它们共同说明了一点:安全拦截既可能是模型消息里的停止信号,也可能是 API 层错误,应用不能只准备一种处理路径。[Azure OpenAI content filtering](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter) [Amazon Bedrock Guardrails](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails-use-converse-api.html) + +## DeerFlow 在这条边界上做什么 + +`SafetyFinishReasonMiddleware` 的职责很窄:它不替代 provider 的内容审核,也不把所有拒答都改写成同一种错误。它只在下面两个条件同时成立时介入: + +1. provider 响应携带了已配置的安全中止信号。 +2. 当前 `AIMessage` 仍包含非空的 `tool_calls`。 + +介入后,它会: + +1. 清空结构化工具调用以及 raw provider metadata 中残留的工具调用字段。 +2. 阻止这些工具参数进入工具节点执行。 +3. 保留已经生成的部分文本,并追加面向用户的说明。 +4. 记录 detector、reason 字段、reason 值、被压制的工具名和数量。 +5. 避免把可能正是被过滤内容的工具参数再次写入审计事件。 + +这意味着安全中止信号的优先级高于“响应里看到了工具调用”。对于 Agent 运行时,这是更保守也更正确的控制流。 + +## 默认配置 + +默认情况下只需要启用 `safety_finish_reason`: + +```yaml +safety_finish_reason: + enabled: true +``` + +不显式配置 `detectors` 时,DeerFlow 使用内置 detector 集合: + +| Detector | 默认匹配 | +| --- | --- | +| `OpenAICompatibleContentFilterDetector` | `finish_reason="content_filter"` | +| `AnthropicRefusalDetector` | `stop_reason="refusal"` | +| `GeminiSafetyDetector` | Gemini 安全相关 `finish_reason`,例如 `SAFETY`、`BLOCKLIST`、`PROHIBITED_CONTENT`、`SPII`、`RECITATION` | + +这个默认集合覆盖了 DeerFlow 常见的 OpenAI-compatible provider、Anthropic 和 Gemini 路径。它不会把普通 `finish_reason="tool_calls"` 当成安全中止,也不会把 `length`、`max_tokens` 之类的长度截断混入安全分类。 + +## 例子:为 GLM 扩展流式安全结束信号 + +GLM 流式响应使用的安全结束值是 `sensitive`。如果当前适配层把这个值保留在 `AIMessage.response_metadata.finish_reason` 或 `additional_kwargs.finish_reason` 中,可以通过 OpenAI-compatible detector 的可配置 finish reason 集合接入: + +```yaml +safety_finish_reason: + enabled: true + detectors: + - use: deerflow.agents.middlewares.safety_termination_detectors:OpenAICompatibleContentFilterDetector + config: + finish_reasons: ["content_filter", "sensitive"] + + - use: deerflow.agents.middlewares.safety_termination_detectors:AnthropicRefusalDetector + + - use: deerflow.agents.middlewares.safety_termination_detectors:GeminiSafetyDetector +``` + +这里有两个配置细节需要注意。 + +第一,`detectors` 是覆盖默认列表,不是向默认列表追加一项。因此为了给 GLM 增加 `sensitive`,示例里也保留了 Anthropic 和 Gemini detector。 + +第二,这个 middleware 处理的是已经进入模型消息的安全结束信号。如果 provider 在 API 层直接返回安全审核错误,例如 GLM 同步调用的安全审核错误码,调用方还需要在 LLM/API 错误处理路径里单独处理它。 + + +## 边界 + +`SafetyFinishReasonMiddleware` 解决的是一个明确的 Agent 控制流问题,不是完整的内容安全方案。它不替代 moderation、权限隔离、用户治理或 provider 自身的审核策略,也不覆盖每一种普通文本拒答。 + +但这一条边界值得单独守住:当 provider 已经因为安全原因停下输出时,工具型 Agent 应把这一轮视为被中断的输出,而不是可执行的工具意图。 From 2eeb597985cdd1f1710065a05f46021f1eea10dc Mon Sep 17 00:00:00 2001 From: Lawrance_YXLiao <32213920+kibabsquirrel@users.noreply.github.com> Date: Fri, 22 May 2026 21:42:14 +0800 Subject: [PATCH 69/86] fix(runs): expose active progress counters (#3148) * fix(runs): expose active progress counters * fix(runs): avoid delayed progress flush on completion * fix(runs): tighten progress snapshot semantics * fix(runs): preserve omitted progress fields * chore(runs): remove duplicate journal initialization --- backend/app/gateway/routers/thread_runs.py | 27 +++- .../harness/deerflow/persistence/run/sql.py | 43 +++++- .../harness/deerflow/runtime/journal.py | 73 ++++++++++- .../harness/deerflow/runtime/runs/manager.py | 47 +++++++ .../deerflow/runtime/runs/store/base.py | 20 ++- .../deerflow/runtime/runs/store/memory.py | 12 +- .../harness/deerflow/runtime/runs/worker.py | 3 +- backend/tests/test_run_journal.py | 104 +++++++++++++++ backend/tests/test_run_repository.py | 122 ++++++++++++++++++ backend/tests/test_thread_token_usage.py | 27 ++++ 10 files changed, 468 insertions(+), 10 deletions(-) diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index 294fa9799..a542593b2 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -66,6 +66,14 @@ class RunResponse(BaseModel): multitask_strategy: str = "reject" created_at: str = "" updated_at: str = "" + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + llm_call_count: int = 0 + lead_agent_tokens: int = 0 + subagent_tokens: int = 0 + middleware_tokens: int = 0 + message_count: int = 0 class ThreadTokenUsageModelBreakdown(BaseModel): @@ -111,6 +119,14 @@ def _record_to_response(record: RunRecord) -> RunResponse: multitask_strategy=record.multitask_strategy, created_at=record.created_at, updated_at=record.updated_at, + total_input_tokens=record.total_input_tokens, + total_output_tokens=record.total_output_tokens, + total_tokens=record.total_tokens, + llm_call_count=record.llm_call_count, + lead_agent_tokens=record.lead_agent_tokens, + subagent_tokens=record.subagent_tokens, + middleware_tokens=record.middleware_tokens, + message_count=record.message_count, ) @@ -402,8 +418,15 @@ async def list_run_events( @router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse) @require_permission("threads", "read", owner_check=True) -async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse: +async def thread_token_usage( + thread_id: str, + request: Request, + include_active: bool = Query(default=False, description="Include running run progress snapshots"), +) -> ThreadTokenUsageResponse: """Thread-level token usage aggregation.""" run_store = get_run_store(request) - agg = await run_store.aggregate_tokens_by_thread(thread_id) + if include_active: + agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True) + else: + agg = await run_store.aggregate_tokens_by_thread(thread_id) return ThreadTokenUsageResponse(thread_id=thread_id, **agg) diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 5679cc68f..1be9fb159 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -227,9 +227,48 @@ class RunRepository(RunStore): await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.commit() - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: + async def update_run_progress( + self, + run_id: str, + *, + total_input_tokens: int | None = None, + total_output_tokens: int | None = None, + total_tokens: int | None = None, + llm_call_count: int | None = None, + lead_agent_tokens: int | None = None, + subagent_tokens: int | None = None, + middleware_tokens: int | None = None, + message_count: int | None = None, + last_ai_message: str | None = None, + first_human_message: str | None = None, + ) -> None: + """Update token usage + convenience fields while a run is still active.""" + values: dict[str, Any] = {"updated_at": datetime.now(UTC)} + optional_counters = { + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "total_tokens": total_tokens, + "llm_call_count": llm_call_count, + "lead_agent_tokens": lead_agent_tokens, + "subagent_tokens": subagent_tokens, + "middleware_tokens": middleware_tokens, + "message_count": message_count, + } + for key, value in optional_counters.items(): + if value is not None: + values[key] = value + if last_ai_message is not None: + values["last_ai_message"] = last_ai_message[:2000] + if first_human_message is not None: + values["first_human_message"] = first_human_message[:2000] + async with self._sf() as session: + await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values)) + await session.commit() + + async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: """Aggregate token usage via a single SQL GROUP BY query.""" - _completed = RunRow.status.in_(("success", "error")) + statuses = ("success", "error", "running") if include_active else ("success", "error") + _completed = RunRow.status.in_(statuses) _thread = RunRow.thread_id == thread_id model_name = func.coalesce(RunRow.model_name, "unknown") diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index 8a9382e23..a12ebd98b 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -20,7 +20,7 @@ from __future__ import annotations import asyncio import logging import time -from collections.abc import Mapping +from collections.abc import Awaitable, Callable, Mapping from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, cast from uuid import UUID @@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler): *, track_token_usage: bool = True, flush_threshold: int = 20, + progress_reporter: Callable[[dict], Awaitable[None]] | None = None, + progress_flush_interval: float = 5.0, ): super().__init__() self.run_id = run_id @@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler): self._store = event_store self._track_tokens = track_token_usage self._flush_threshold = flush_threshold + self._progress_reporter = progress_reporter + self._progress_flush_interval = progress_flush_interval # Write buffer self._buffer: list[dict] = [] self._pending_flush_tasks: set[asyncio.Task[None]] = set() + self._pending_progress_task: asyncio.Task[None] | None = None + self._pending_progress_delayed = False + self._progress_dirty = False + self._last_progress_flush = 0.0 # Token accumulators self._total_input_tokens = 0 @@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler): else: self._lead_agent_tokens += total_tk + self._schedule_progress_flush() + if messages: self._counted_message_llm_run_ids.add(str(run_id)) @@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler): else: self._lead_agent_tokens += total_tk + self._schedule_progress_flush() + def set_first_human_message(self, content: str) -> None: """Record the first human message for convenience fields.""" self._first_human_msg = content[:2000] if content else None @@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler): """Force flush remaining buffer. Called in worker's finally block.""" if self._pending_flush_tasks: await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) + while self._pending_progress_task is not None and not self._pending_progress_task.done(): + if self._pending_progress_delayed: + self._pending_progress_task.cancel() + await asyncio.gather(self._pending_progress_task, return_exceptions=True) + self._progress_dirty = False + self._pending_progress_delayed = False + break + await asyncio.gather(self._pending_progress_task, return_exceptions=True) while self._buffer: batch = self._buffer[: self._flush_threshold] @@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler): self._buffer = batch + self._buffer raise + def _schedule_progress_flush(self) -> None: + """Best-effort throttled progress snapshot for active run visibility.""" + if self._progress_reporter is None: + return + now = time.monotonic() + elapsed = now - self._last_progress_flush + if elapsed < self._progress_flush_interval: + self._progress_dirty = True + self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed) + return + if self._pending_progress_task is not None and not self._pending_progress_task.done(): + self._progress_dirty = True + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + self._progress_dirty = False + self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data())) + + def _schedule_delayed_progress_flush(self, delay: float) -> None: + if self._pending_progress_task is not None and not self._pending_progress_task.done(): + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + delay = max(0.0, delay) + self._pending_progress_delayed = delay > 0 + self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay)) + + async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None: + if self._progress_reporter is None: + return + if delay > 0: + self._pending_progress_delayed = True + await asyncio.sleep(delay) + self._pending_progress_delayed = False + dirty_before_write = self._progress_dirty + self._progress_dirty = False + snapshot_to_write = snapshot or self.get_completion_data() + try: + await self._progress_reporter(snapshot_to_write) + self._last_progress_flush = time.monotonic() + except Exception: + logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True) + if dirty_before_write or self._progress_dirty: + self._progress_dirty = False + self._pending_progress_task = None + self._schedule_delayed_progress_flush(self._progress_flush_interval) + def get_completion_data(self) -> dict: """Return accumulated token and message data for run completion.""" return { diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index ea78f89c9..5387689dc 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -38,6 +38,16 @@ class RunRecord: error: str | None = None model_name: str | None = None store_only: bool = False + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + llm_call_count: int = 0 + lead_agent_tokens: int = 0 + subagent_tokens: int = 0 + middleware_tokens: int = 0 + message_count: int = 0 + last_ai_message: str | None = None + first_human_message: str | None = None class RunManager: @@ -102,16 +112,53 @@ class RunManager: error=row.get("error"), model_name=row.get("model_name"), store_only=True, + total_input_tokens=row.get("total_input_tokens") or 0, + total_output_tokens=row.get("total_output_tokens") or 0, + total_tokens=row.get("total_tokens") or 0, + llm_call_count=row.get("llm_call_count") or 0, + lead_agent_tokens=row.get("lead_agent_tokens") or 0, + subagent_tokens=row.get("subagent_tokens") or 0, + middleware_tokens=row.get("middleware_tokens") or 0, + message_count=row.get("message_count") or 0, + last_ai_message=row.get("last_ai_message"), + first_human_message=row.get("first_human_message"), ) async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" + async with self._lock: + record = self._runs.get(run_id) + if record is not None: + for key, value in kwargs.items(): + if key == "status": + continue + if hasattr(record, key) and value is not None: + setattr(record, key, value) + record.updated_at = _now_iso() if self._store is not None: try: await self._store.update_run_completion(run_id, **kwargs) except Exception: logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) + async def update_run_progress(self, run_id: str, **kwargs) -> None: + """Persist a running token/message snapshot without changing status.""" + should_persist = True + async with self._lock: + record = self._runs.get(run_id) + if record is not None: + should_persist = record.status == RunStatus.running + if record is not None and should_persist: + for key, value in kwargs.items(): + if hasattr(record, key) and value is not None: + setattr(record, key, value) + record.updated_at = _now_iso() + if should_persist and self._store is not None: + try: + await self._store.update_run_progress(run_id, **kwargs) + except Exception: + logger.warning("Failed to persist run progress for %s", run_id, exc_info=True) + async def create( self, thread_id: str, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 10c90d7ea..c5ac18212 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -95,12 +95,30 @@ class RunStore(abc.ABC): ) -> None: pass + async def update_run_progress( + self, + run_id: str, + *, + total_input_tokens: int | None = None, + total_output_tokens: int | None = None, + total_tokens: int | None = None, + llm_call_count: int | None = None, + lead_agent_tokens: int | None = None, + subagent_tokens: int | None = None, + middleware_tokens: int | None = None, + message_count: int | None = None, + last_ai_message: str | None = None, + first_human_message: str | None = None, + ) -> None: + """Persist a best-effort running snapshot without changing run status.""" + return None + @abc.abstractmethod async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: pass @abc.abstractmethod - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: + async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: """Aggregate token usage for completed runs in a thread. Returns a dict with keys: total_tokens, total_input_tokens, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 56ef02b5b..d241f2ecc 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -82,14 +82,22 @@ class MemoryRunStore(RunStore): self._runs[run_id][key] = value self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def update_run_progress(self, run_id, **kwargs): + if run_id in self._runs and self._runs[run_id].get("status") == "running": + for key, value in kwargs.items(): + if value is not None: + self._runs[run_id][key] = value + self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def list_pending(self, *, before=None): now = before or datetime.now(UTC).isoformat() results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now] results.sort(key=lambda r: r["created_at"]) return results - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: - completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")] + async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: + statuses = ("success", "error", "running") if include_active else ("success", "error") + completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses] by_model: dict[str, dict] = {} for r in completed: model = r.get("model_name") or "unknown" diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 694464fe3..d84b3edf9 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -153,8 +153,6 @@ async def run_agent( journal = None - journal = None - # Track whether "events" was requested but skipped if "events" in requested_modes: logger.info( @@ -177,6 +175,7 @@ async def run_agent( thread_id=thread_id, event_store=event_store, track_token_usage=getattr(run_events_config, "track_token_usage", True), + progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot), ) # 1. Mark running diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index 8615caa49..0b495954b 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -714,6 +714,110 @@ class TestExternalUsageRecords: assert j._subagent_tokens == 0 +class TestProgressSnapshots: + @pytest.mark.anyio + async def test_on_llm_end_reports_progress_snapshot(self): + snapshots: list[dict] = [] + + async def reporter(snapshot: dict) -> None: + snapshots.append(snapshot) + + store = MemoryRunEventStore() + j = RunJournal( + "r1", + "t1", + store, + flush_threshold=100, + progress_reporter=reporter, + progress_flush_interval=0, + ) + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + await j.flush() + + assert snapshots + assert snapshots[-1]["total_tokens"] == 15 + assert snapshots[-1]["llm_call_count"] == 1 + assert snapshots[-1]["message_count"] == 1 + assert snapshots[-1]["last_ai_message"] == "Answer" + + @pytest.mark.anyio + async def test_throttled_progress_flush_emits_trailing_snapshot(self): + snapshots: list[dict] = [] + trailing_seen = asyncio.Event() + + async def reporter(snapshot: dict) -> None: + snapshots.append(snapshot) + if snapshot["total_tokens"] == 45: + trailing_seen.set() + + store = MemoryRunEventStore() + j = RunJournal( + "r1", + "t1", + store, + flush_threshold=100, + progress_reporter=reporter, + progress_flush_interval=0.01, + ) + j.on_llm_end( + _make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + j.on_llm_end( + _make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + await asyncio.wait_for(trailing_seen.wait(), timeout=1.0) + await j.flush() + + assert len(snapshots) >= 2 + assert snapshots[-1]["total_tokens"] == 45 + assert snapshots[-1]["llm_call_count"] == 2 + assert snapshots[-1]["last_ai_message"] == "Second" + + @pytest.mark.anyio + async def test_flush_cancels_delayed_progress_without_final_progress_write(self): + snapshots: list[dict] = [] + + async def reporter(snapshot: dict) -> None: + snapshots.append(snapshot) + + store = MemoryRunEventStore() + j = RunJournal( + "r1", + "t1", + store, + flush_threshold=100, + progress_reporter=reporter, + progress_flush_interval=10.0, + ) + j.on_llm_end( + _make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + await asyncio.sleep(0) + assert snapshots[-1]["total_tokens"] == 15 + j.on_llm_end( + _make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + + await asyncio.wait_for(j.flush(), timeout=0.2) + + assert snapshots[-1]["total_tokens"] == 15 + assert snapshots[-1]["llm_call_count"] == 1 + assert snapshots[-1]["last_ai_message"] == "First" + + class TestChatModelStartHumanMessage: """Tests for on_chat_model_start extracting the first human message.""" diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 5809db517..f18e51348 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql from deerflow.persistence.run import RunRepository from deerflow.runtime import RunManager, RunStatus +from deerflow.runtime.runs.store.base import RunStore async def _make_repo(tmp_path): @@ -26,6 +27,42 @@ async def _cleanup(): await close_engine() +class _CustomRunStoreWithoutProgress(RunStore): + async def put(self, *args, **kwargs): + return None + + async def get(self, *args, **kwargs): + return None + + async def list_by_thread(self, *args, **kwargs): + return [] + + async def update_status(self, *args, **kwargs): + return None + + async def delete(self, *args, **kwargs): + return None + + async def update_model_name(self, *args, **kwargs): + return None + + async def update_run_completion(self, *args, **kwargs): + return None + + async def list_pending(self, *args, **kwargs): + return [] + + async def aggregate_tokens_by_thread(self, *args, **kwargs): + return {} + + +@pytest.mark.anyio +async def test_update_run_progress_defaults_to_noop_for_custom_store(): + store = _CustomRunStoreWithoutProgress() + + await store.update_run_progress("r1", total_tokens=1) + + class TestRunRepository: @pytest.mark.anyio async def test_put_and_get(self, tmp_path): @@ -170,6 +207,69 @@ class TestRunRepository: assert row["total_tokens"] == 100 await _cleanup() + @pytest.mark.anyio + async def test_update_run_progress_keeps_status_running(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="running") + await repo.update_run_progress( + "r1", + total_input_tokens=40, + total_output_tokens=10, + total_tokens=50, + llm_call_count=1, + message_count=2, + last_ai_message="partial answer", + ) + row = await repo.get("r1") + assert row["status"] == "running" + assert row["total_tokens"] == 50 + assert row["llm_call_count"] == 1 + assert row["message_count"] == 2 + assert row["last_ai_message"] == "partial answer" + await _cleanup() + + @pytest.mark.anyio + async def test_update_run_progress_preserves_omitted_fields(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="running") + await repo.update_run_progress( + "r1", + total_input_tokens=40, + total_output_tokens=10, + total_tokens=50, + llm_call_count=1, + lead_agent_tokens=30, + subagent_tokens=20, + message_count=2, + ) + + await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated") + + row = await repo.get("r1") + assert row["total_input_tokens"] == 40 + assert row["total_output_tokens"] == 10 + assert row["total_tokens"] == 60 + assert row["llm_call_count"] == 1 + assert row["lead_agent_tokens"] == 30 + assert row["subagent_tokens"] == 20 + assert row["message_count"] == 2 + assert row["last_ai_message"] == "updated" + await _cleanup() + + @pytest.mark.anyio + async def test_update_run_progress_skips_terminal_runs(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="running") + await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1) + + await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2) + + row = await repo.get("r1") + assert row["status"] == "success" + assert row["total_tokens"] == 100 + assert row["llm_call_count"] == 1 + await _cleanup() + @pytest.mark.anyio async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path): repo = await _make_repo(tmp_path) @@ -225,6 +325,28 @@ class TestRunRepository: } await _cleanup() + @pytest.mark.anyio + async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("success-run", thread_id="t1", status="running") + await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100) + await repo.put("running-run", thread_id="t1", status="running") + await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5) + + without_active = await repo.aggregate_tokens_by_thread("t1") + with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True) + + assert without_active["total_tokens"] == 100 + assert without_active["total_runs"] == 1 + assert with_active["total_tokens"] == 125 + assert with_active["total_runs"] == 2 + assert with_active["by_caller"] == { + "lead_agent": 120, + "subagent": 5, + "middleware": 0, + } + await _cleanup() + @pytest.mark.anyio async def test_list_by_thread_ordered_desc(self, tmp_path): """list_by_thread returns newest first.""" diff --git a/backend/tests/test_thread_token_usage.py b/backend/tests/test_thread_token_usage.py index 713f6aa5f..19f8e0c19 100644 --- a/backend/tests/test_thread_token_usage.py +++ b/backend/tests/test_thread_token_usage.py @@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape(): }, } run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1") + + +def test_thread_token_usage_can_include_active_runs(): + run_store = MagicMock() + run_store.aggregate_tokens_by_thread = AsyncMock( + return_value={ + "total_tokens": 175, + "total_input_tokens": 120, + "total_output_tokens": 55, + "total_runs": 3, + "by_model": {"unknown": {"tokens": 175, "runs": 3}}, + "by_caller": { + "lead_agent": 145, + "subagent": 25, + "middleware": 5, + }, + }, + ) + app = _make_app(run_store) + + with TestClient(app) as client: + response = client.get("/api/threads/thread-1/token-usage?include_active=true") + + assert response.status_code == 200 + assert response.json()["total_tokens"] == 175 + assert response.json()["total_runs"] == 3 + run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True) From f0bae286366074408cff07635df7c9e1b439fcf4 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 22 May 2026 15:44:05 +0200 Subject: [PATCH 70/86] fix(middleware): handle repeated tool call ids (#3143) * fix(middleware): handle repeated tool call ids * add tests * refactor(middleware): rely on tool result queues --- .../dangling_tool_call_middleware.py | 13 ++-- .../test_dangling_tool_call_middleware.py | 64 +++++++++++++++++++ 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py index 000ca51a2..6026d834e 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/dangling_tool_call_middleware.py @@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do. import json import logging +from collections import defaultdict, deque from collections.abc import Awaitable, Callable from typing import override @@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): This normalizes model-bound causal order before provider serialization while preserving already-valid transcripts unchanged. """ - tool_messages_by_id: dict[str, ToolMessage] = {} + tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque) for msg in messages: if isinstance(msg, ToolMessage): - tool_messages_by_id.setdefault(msg.tool_call_id, msg) + tool_messages_by_id[msg.tool_call_id].append(msg) tool_call_ids: set[str] = set() for msg in messages: @@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): tool_call_ids.add(tc_id) patched: list = [] - consumed_tool_msg_ids: set[str] = set() patch_count = 0 for msg in messages: if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: @@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): for tc in self._message_tool_calls(msg): tc_id = tc.get("id") - if not tc_id or tc_id in consumed_tool_msg_ids: + if not tc_id: continue - existing_tool_msg = tool_messages_by_id.get(tc_id) + tool_msg_queue = tool_messages_by_id.get(tc_id) + existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None if existing_tool_msg is not None: patched.append(existing_tool_msg) - consumed_tool_msg_ids.add(tc_id) else: patched.append( ToolMessage( @@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]): status="error", ) ) - consumed_tool_msg_ids.add(tc_id) patch_count += 1 if patched == messages: diff --git a/backend/tests/test_dangling_tool_call_middleware.py b/backend/tests/test_dangling_tool_call_middleware.py index 5ecded924..34f1ac035 100644 --- a/backend/tests/test_dangling_tool_call_middleware.py +++ b/backend/tests/test_dangling_tool_call_middleware.py @@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching: assert mw._build_patched_messages(msgs) is None + def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self): + mw = DanglingToolCallMiddleware() + msgs = [ + HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}), + _ai_with_tool_calls( + [ + _tc("web_search", "web_search:11"), + _tc("web_search", "web_search:12"), + _tc("web_search", "web_search:13"), + ] + ), + _tool_msg("web_search:11", "web_search"), + _tool_msg("web_search:12", "web_search"), + _tool_msg("web_search:13", "web_search"), + _ai_with_tool_calls( + [ + _tc("web_search", "web_search:9"), + _tc("web_search", "web_search:10"), + _tc("web_search", "web_search:11"), + ] + ), + _tool_msg("web_search:9", "web_search"), + _tool_msg("web_search:10", "web_search"), + _tool_msg("web_search:11", "web_search"), + ] + + assert mw._build_patched_messages(msgs) is None + + def test_reused_tool_call_id_patches_second_dangling_occurrence(self): + mw = DanglingToolCallMiddleware() + msgs = [ + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + _tool_msg("web_search:11", "web_search"), + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert isinstance(patched[1], ToolMessage) + assert patched[1].tool_call_id == "web_search:11" + assert patched[1].status == "success" + assert isinstance(patched[3], ToolMessage) + assert patched[3].tool_call_id == "web_search:11" + assert patched[3].status == "error" + + def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self): + mw = DanglingToolCallMiddleware() + result = _tool_msg("web_search:11", "web_search") + msgs = [ + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + _ai_with_tool_calls([_tc("web_search", "web_search:11")]), + result, + ] + + patched = mw._build_patched_messages(msgs) + + assert patched is not None + assert patched[1] is result + assert patched[1].status == "success" + assert isinstance(patched[3], ToolMessage) + assert patched[3].tool_call_id == "web_search:11" + assert patched[3].status == "error" + def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): mw = DanglingToolCallMiddleware() msgs = [ From 66d6a6a4e85e8d2accebf722a3a7eb70db2346e9 Mon Sep 17 00:00:00 2001 From: AochenShen99 Date: Sat, 23 May 2026 00:09:06 +0800 Subject: [PATCH 71/86] fix: harden run finalization persistence (#3155) * fix: harden run finalization persistence * style: format gateway recovery test * fix: align run repository return types * fix: harden completion recovery follow-up --- backend/app/gateway/deps.py | 35 +++ .../harness/deerflow/persistence/run/sql.py | 77 ++++-- .../harness/deerflow/runtime/runs/manager.py | 255 +++++++++++++++--- .../deerflow/runtime/runs/store/base.py | 18 +- .../deerflow/runtime/runs/store/memory.py | 10 + backend/tests/test_gateway_run_recovery.py | 127 +++++++++ backend/tests/test_run_manager.py | 240 +++++++++++++++++ backend/tests/test_run_repository.py | 49 +++- 8 files changed, 755 insertions(+), 56 deletions(-) create mode 100644 backend/tests/test_gateway_run_recovery.py diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index f045a2ee3..7f9674070 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -37,11 +37,36 @@ if TYPE_CHECKING: from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from deerflow.persistence.thread_meta.base import ThreadMetaStore + from deerflow.runtime import RunRecord T = TypeVar("T") +async def _mark_latest_recovered_threads_error( + run_manager: RunManager, + thread_store: ThreadMetaStore, + recovered_runs: list[RunRecord], +) -> None: + """Mark thread status as error only when its newest run was recovered.""" + recovered_by_thread: dict[str, set[str]] = {} + for record in recovered_runs: + recovered_by_thread.setdefault(record.thread_id, set()).add(record.run_id) + + for thread_id, recovered_run_ids in recovered_by_thread.items(): + try: + latest_runs = await run_manager.list_by_thread(thread_id, user_id=None, limit=1) + except Exception: + logger.warning("Failed to find latest run for thread %s during run reconciliation", thread_id, exc_info=True) + continue + if not latest_runs or latest_runs[0].run_id not in recovered_run_ids: + continue + try: + await thread_store.update_status(thread_id, "error", user_id=None) + except Exception: + logger.warning("Failed to mark thread %s as error during run reconciliation", thread_id, exc_info=True) + + def get_config() -> AppConfig: """Return the freshest ``AppConfig`` for the current request. @@ -138,6 +163,16 @@ async def langgraph_runtime(app: FastAPI, startup_config: AppConfig) -> AsyncGen # RunManager with store backing for persistence app.state.run_manager = RunManager(store=app.state.run_store) + if getattr(config.database, "backend", None) == "sqlite": + from deerflow.utils.time import now_iso + + # Startup-only recovery: clean shutdowns return no active rows and + # the thread-status update below becomes a no-op. + recovered_runs = await app.state.run_manager.reconcile_orphaned_inflight_runs( + error="Gateway restarted before this run reached a durable final state.", + before=now_iso(), + ) + await _mark_latest_recovered_threads_error(app.state.run_manager, app.state.thread_store, recovered_runs) try: yield diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 1be9fb159..7ca2ea1e1 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -94,25 +94,35 @@ class RunRepository(RunStore): created_at=None, follow_up_to_run_id=None, ): + """Insert or update a run row. + + ``RunManager`` retries ``put`` after transient SQLite failures. Making + this operation idempotent prevents a successful-but-unacknowledged first + commit from turning the retry into a primary-key failure. + """ resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put") now = datetime.now(UTC) - row = RunRow( - run_id=run_id, - thread_id=thread_id, - assistant_id=assistant_id, - user_id=resolved_user_id, - model_name=self._normalize_model_name(model_name), - status=status, - multitask_strategy=multitask_strategy, - metadata_json=self._safe_json(metadata) or {}, - kwargs_json=self._safe_json(kwargs) or {}, - error=error, - follow_up_to_run_id=follow_up_to_run_id, - created_at=datetime.fromisoformat(created_at) if created_at else now, - updated_at=now, - ) + created = datetime.fromisoformat(created_at) if created_at else now + values = { + "thread_id": thread_id, + "assistant_id": assistant_id, + "user_id": resolved_user_id, + "model_name": self._normalize_model_name(model_name), + "status": status, + "multitask_strategy": multitask_strategy, + "metadata_json": self._safe_json(metadata) or {}, + "kwargs_json": self._safe_json(kwargs) or {}, + "error": error, + "follow_up_to_run_id": follow_up_to_run_id, + "updated_at": now, + } async with self._sf() as session: - session.add(row) + row = await session.get(RunRow, run_id) + if row is None: + session.add(RunRow(run_id=run_id, created_at=created, **values)) + else: + for key, value in values.items(): + setattr(row, key, value) await session.commit() async def get( @@ -146,13 +156,14 @@ class RunRepository(RunStore): result = await session.execute(stmt) return [self._row_to_dict(r) for r in result.scalars()] - async def update_status(self, run_id, status, *, error=None): + async def update_status(self, run_id, status, *, error=None) -> bool: values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)} if error is not None: values["error"] = error async with self._sf() as session: - await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) + result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.commit() + return result.rowcount != 0 async def update_model_name(self, run_id, model_name): async with self._sf() as session: @@ -187,6 +198,26 @@ class RunRepository(RunStore): result = await session.execute(stmt) return [self._row_to_dict(r) for r in result.scalars()] + async def list_inflight(self, *, before=None): + """Return persisted active runs for startup recovery.""" + if before is None: + before_dt = datetime.now(UTC) + elif isinstance(before, datetime): + before_dt = before + else: + before_dt = datetime.fromisoformat(before) + stmt = ( + select(RunRow) + .where( + RunRow.status.in_(("pending", "running")), + RunRow.created_at <= before_dt, + ) + .order_by(RunRow.created_at.asc()) + ) + async with self._sf() as session: + result = await session.execute(stmt) + return [self._row_to_dict(r) for r in result.scalars()] + async def update_run_completion( self, run_id: str, @@ -203,8 +234,11 @@ class RunRepository(RunStore): last_ai_message: str | None = None, first_human_message: str | None = None, error: str | None = None, - ) -> None: - """Update status + token usage + convenience fields on run completion.""" + ) -> bool: + """Update status + token usage + convenience fields on run completion. + + Returns ``False`` when no run row matched the requested ``run_id``. + """ values: dict[str, Any] = { "status": status, "total_input_tokens": total_input_tokens, @@ -224,8 +258,9 @@ class RunRepository(RunStore): if error is not None: values["error"] = error async with self._sf() as session: - await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) + result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.commit() + return result.rowcount != 0 async def update_run_progress( self, diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index 5387689dc..c6bb5be26 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -4,7 +4,9 @@ from __future__ import annotations import asyncio import logging +import sqlite3 import uuid +from collections.abc import Awaitable, Callable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any @@ -17,6 +19,57 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +_RETRYABLE_SQLITE_MESSAGES = ( + "database is locked", + "database table is locked", + "database is busy", +) + +_RETRYABLE_SQLITE_ERROR_CODES = { + sqlite3.SQLITE_BUSY, + sqlite3.SQLITE_LOCKED, +} + + +def _is_retryable_persistence_error(exc: BaseException) -> bool: + """Return True for transient SQLite persistence failures. + + SQLite lock contention normally surfaces through either sqlite3 exceptions + or SQLAlchemy wrappers. The short bounded retry here protects run status + finalization from transient writer pressure without hiding permanent + failures forever. + """ + + pending: list[BaseException] = [exc] + seen: set[int] = set() + while pending: + current = pending.pop() + if id(current) in seen: + continue + seen.add(id(current)) + + message = str(current).lower() + if any(fragment in message for fragment in _RETRYABLE_SQLITE_MESSAGES): + return True + if isinstance(current, (sqlite3.OperationalError, sqlite3.DatabaseError)): + error_code = getattr(current, "sqlite_errorcode", None) + if error_code in _RETRYABLE_SQLITE_ERROR_CODES: + return True + for chained in (getattr(current, "orig", None), current.__cause__, current.__context__): + if isinstance(chained, BaseException): + pending.append(chained) + return False + + +@dataclass(frozen=True) +class PersistenceRetryPolicy: + """Bounded retry policy for short run-store writes.""" + + max_attempts: int = 5 + initial_delay: float = 0.05 + max_delay: float = 1.0 + backoff_factor: float = 2.0 + @dataclass class RunRecord: @@ -58,38 +111,100 @@ class RunManager: that run history survives process restarts. """ - def __init__(self, store: RunStore | None = None) -> None: + def __init__( + self, + store: RunStore | None = None, + *, + persistence_retry_policy: PersistenceRetryPolicy | None = None, + ) -> None: self._runs: dict[str, RunRecord] = {} self._lock = asyncio.Lock() self._store = store + self._persistence_retry_policy = persistence_retry_policy or PersistenceRetryPolicy() - async def _persist_to_store(self, record: RunRecord) -> None: - """Best-effort persist run record to backing store.""" + @staticmethod + def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: + return { + "thread_id": record.thread_id, + "assistant_id": record.assistant_id, + "status": record.status.value, + "multitask_strategy": record.multitask_strategy, + "metadata": record.metadata or {}, + "kwargs": record.kwargs or {}, + "error": error if error is not None else record.error, + "created_at": record.created_at, + "model_name": record.model_name, + } + + async def _call_store_with_retry( + self, + operation_name: str, + run_id: str, + operation: Callable[[], Awaitable[Any]], + ) -> Any: + """Run a short store operation with bounded retries for SQLite pressure.""" + policy = self._persistence_retry_policy + attempt = 1 + delay = policy.initial_delay + while True: + try: + return await operation() + except Exception as exc: + retryable = _is_retryable_persistence_error(exc) + if attempt >= policy.max_attempts or not retryable: + raise + logger.warning( + "Transient persistence failure during %s for run %s (attempt %d/%d); retrying", + operation_name, + run_id, + attempt, + policy.max_attempts, + exc_info=True, + ) + if delay > 0: + await asyncio.sleep(delay) + delay = min(policy.max_delay, delay * policy.backoff_factor if delay else policy.initial_delay) + attempt += 1 + + async def _persist_snapshot_to_store(self, run_id: str, payload: dict[str, Any]) -> bool: + """Best-effort persist a previously captured run snapshot.""" if self._store is None: - return + return True try: - await self._store.put( - record.run_id, - thread_id=record.thread_id, - assistant_id=record.assistant_id, - status=record.status.value, - multitask_strategy=record.multitask_strategy, - metadata=record.metadata or {}, - kwargs=record.kwargs or {}, - created_at=record.created_at, - model_name=record.model_name, + await self._call_store_with_retry( + "put", + run_id, + lambda: self._store.put(run_id, **payload), ) + return True except Exception: - logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True) + logger.warning("Failed to persist run %s to store", run_id, exc_info=True) + return False - async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: + async def _persist_to_store(self, record: RunRecord, *, error: str | None = None) -> bool: + """Best-effort persist run record to backing store.""" + return await self._persist_snapshot_to_store( + record.run_id, + self._store_put_payload(record, error=error), + ) + + async def _persist_status(self, record: RunRecord, status: RunStatus, *, error: str | None = None) -> bool: """Best-effort persist a status transition to the backing store.""" if self._store is None: - return + return True + row_recovery_payload = self._store_put_payload(record, error=error) try: - await self._store.update_status(run_id, status.value, error=error) + updated = await self._call_store_with_retry( + "update_status", + record.run_id, + lambda: self._store.update_status(record.run_id, status.value, error=error), + ) + if updated is False: + return await self._persist_snapshot_to_store(record.run_id, row_recovery_payload) + return True except Exception: - logger.warning("Failed to persist status update for run %s", run_id, exc_info=True) + logger.warning("Failed to persist status update for run %s", record.run_id, exc_info=True) + return False @staticmethod def _record_from_store(row: dict[str, Any]) -> RunRecord: @@ -126,6 +241,7 @@ class RunManager: async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" + row_recovery_payload: dict[str, Any] | None = None async with self._lock: record = self._runs.get(run_id) if record is not None: @@ -135,11 +251,30 @@ class RunManager: if hasattr(record, key) and value is not None: setattr(record, key, value) record.updated_at = _now_iso() - if self._store is not None: - try: - await self._store.update_run_completion(run_id, **kwargs) - except Exception: - logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) + row_recovery_payload = self._store_put_payload(record, error=kwargs.get("error")) + if self._store is None: + return + try: + updated = await self._call_store_with_retry( + "update_run_completion", + run_id, + lambda: self._store.update_run_completion(run_id, **kwargs), + ) + if updated is False: + if row_recovery_payload is None: + logger.warning("Failed to recreate missing run %s for completion persistence", run_id) + return + if not await self._persist_snapshot_to_store(run_id, row_recovery_payload): + return + recovered = await self._call_store_with_retry( + "update_run_completion", + run_id, + lambda: self._store.update_run_completion(run_id, **kwargs), + ) + if recovered is False: + logger.warning("Run completion update for %s affected no rows after row recreation", run_id) + except Exception: + logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) async def update_run_progress(self, run_id: str, **kwargs) -> None: """Persist a running token/message snapshot without changing status.""" @@ -273,7 +408,7 @@ class RunManager: record.updated_at = _now_iso() if error is not None: record.error = error - await self._persist_status(run_id, status, error=error) + await self._persist_status(record, status, error=error) logger.info("Run %s -> %s", run_id, status.value) async def _persist_model_name(self, run_id: str, model_name: str | None) -> None: @@ -281,7 +416,11 @@ class RunManager: if self._store is None: return try: - await self._store.update_model_name(run_id, model_name) + await self._call_store_with_retry( + "update_model_name", + run_id, + lambda: self._store.update_model_name(run_id, model_name), + ) except Exception: logger.warning("Failed to persist model_name update for run %s", run_id, exc_info=True) @@ -324,7 +463,7 @@ class RunManager: record.task.cancel() record.status = RunStatus.interrupted record.updated_at = _now_iso() - await self._persist_status(run_id, RunStatus.interrupted) + await self._persist_status(record, RunStatus.interrupted) logger.info("Run %s cancelled (action=%s)", run_id, action) return True @@ -352,7 +491,7 @@ class RunManager: now = _now_iso() _supported_strategies = ("reject", "interrupt", "rollback") - interrupted_run_ids: list[str] = [] + interrupted_records: list[RunRecord] = [] async with self._lock: if multitask_strategy not in _supported_strategies: @@ -371,7 +510,7 @@ class RunManager: r.task.cancel() r.status = RunStatus.interrupted r.updated_at = now - interrupted_run_ids.append(r.run_id) + interrupted_records.append(r) logger.info( "Cancelled %d inflight run(s) on thread %s (strategy=%s)", len(inflight), @@ -394,12 +533,66 @@ class RunManager: ) self._runs[run_id] = record - for interrupted_run_id in interrupted_run_ids: - await self._persist_status(interrupted_run_id, RunStatus.interrupted) + for interrupted_record in interrupted_records: + await self._persist_status(interrupted_record, RunStatus.interrupted) await self._persist_to_store(record) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) return record + async def reconcile_orphaned_inflight_runs( + self, + *, + error: str, + before: str | None = None, + ) -> list[RunRecord]: + """Mark persisted active runs as failed when no local task owns them. + + Gateway runs are process-local: the asyncio task and abort event live in + memory, while the run row is durable. After a SQLite-backed gateway + restart, any persisted ``pending`` or ``running`` row created before + startup cannot still have a local worker. This recovery step turns that + ambiguous state into an explicit error instead of letting the UI show an + indefinite active run. + """ + if self._store is None: + return [] + try: + rows = await self._call_store_with_retry( + "list_inflight", + "*", + lambda: self._store.list_inflight(before=before), + ) + except Exception: + logger.warning("Failed to list orphaned inflight runs for reconciliation", exc_info=True) + return [] + + recovered: list[RunRecord] = [] + now = _now_iso() + for row in rows: + try: + record = self._record_from_store(row) + except Exception: + logger.warning("Failed to map orphaned run row during reconciliation", exc_info=True) + continue + + async with self._lock: + live_record = self._runs.get(record.run_id) + if live_record is not None and live_record.status in (RunStatus.pending, RunStatus.running): + continue + + record.status = RunStatus.error + record.error = error + record.updated_at = now + persisted = await self._persist_status(record, RunStatus.error, error=error) + if not persisted: + logger.warning("Skipped orphaned run %s recovery because error status was not persisted", record.run_id) + continue + recovered.append(record) + + if recovered: + logger.warning("Recovered %d orphaned inflight run(s) as error", len(recovered)) + return recovered + async def has_inflight(self, thread_id: str) -> bool: """Return ``True`` if *thread_id* has a pending or running run.""" async with self._lock: diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index c5ac18212..071f1436f 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -59,7 +59,12 @@ class RunStore(abc.ABC): status: str, *, error: str | None = None, - ) -> None: + ) -> bool | None: + """Update a run status. + + Returns ``False`` when the store can prove no row was updated. Older or + lightweight stores may return ``None`` when they cannot report rowcount. + """ pass @abc.abstractmethod @@ -92,7 +97,11 @@ class RunStore(abc.ABC): last_ai_message: str | None = None, first_human_message: str | None = None, error: str | None = None, - ) -> None: + ) -> bool | None: + """Persist final completion fields. + + Returns ``False`` when the store can prove no row was updated. + """ pass async def update_run_progress( @@ -117,6 +126,11 @@ class RunStore(abc.ABC): async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: pass + @abc.abstractmethod + async def list_inflight(self, *, before: str | None = None) -> list[dict[str, Any]]: + """Return persisted runs that are still ``pending`` or ``running``.""" + pass + @abc.abstractmethod async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: """Aggregate token usage for completed runs in a thread. diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index d241f2ecc..743240723 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -65,6 +65,8 @@ class MemoryRunStore(RunStore): if error is not None: self._runs[run_id]["error"] = error self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + return True + return False async def update_model_name(self, run_id, model_name): if run_id in self._runs: @@ -81,6 +83,8 @@ class MemoryRunStore(RunStore): if value is not None: self._runs[run_id][key] = value self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + return True + return False async def update_run_progress(self, run_id, **kwargs): if run_id in self._runs and self._runs[run_id].get("status") == "running": @@ -95,6 +99,12 @@ class MemoryRunStore(RunStore): results.sort(key=lambda r: r["created_at"]) return results + async def list_inflight(self, *, before=None): + now = before or datetime.now(UTC).isoformat() + results = [r for r in self._runs.values() if r["status"] in ("pending", "running") and r["created_at"] <= now] + results.sort(key=lambda r: r["created_at"]) + return results + async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: statuses = ("success", "error", "running") if include_active else ("success", "error") completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses] diff --git a/backend/tests/test_gateway_run_recovery.py b/backend/tests/test_gateway_run_recovery.py new file mode 100644 index 000000000..4cabc2147 --- /dev/null +++ b/backend/tests/test_gateway_run_recovery.py @@ -0,0 +1,127 @@ +"""Gateway startup recovery for stale persisted runs.""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from types import SimpleNamespace + +import pytest +from fastapi import FastAPI + +import deerflow.runtime as runtime_module +from app.gateway import deps as gateway_deps +from deerflow.persistence import engine as engine_module +from deerflow.persistence import thread_meta as thread_meta_module +from deerflow.runtime.checkpointer import async_provider as checkpointer_module +from deerflow.runtime.events import store as event_store_module + + +@asynccontextmanager +async def _fake_context(value): + yield value + + +class _FakeRunManager: + """RunManager double that records startup reconciliation calls.""" + + instances: list[_FakeRunManager] = [] + recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")] + latest_by_thread: dict[str, list[SimpleNamespace]] = {} + + def __init__(self, *, store): + self.store = store + self.reconcile_calls: list[dict] = [] + self.list_by_thread_calls: list[dict] = [] + _FakeRunManager.instances.append(self) + + async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None): + self.reconcile_calls.append({"error": error, "before": before}) + return self.recovered_runs + + async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100): + self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit}) + return self.latest_by_thread.get(thread_id, self.recovered_runs[:limit]) + + +class _FakeThreadStore: + def __init__(self) -> None: + self.status_updates: list[tuple[str, str, str | None]] = [] + + async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None: + self.status_updates.append((thread_id, status, user_id)) + + +@pytest.mark.anyio +async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch): + """SQLite startup should recover stale active runs before serving requests.""" + app = FastAPI() + config = SimpleNamespace( + database=SimpleNamespace(backend="sqlite"), + run_events=SimpleNamespace(backend="memory"), + ) + thread_store = _FakeThreadStore() + _FakeRunManager.instances.clear() + _FakeRunManager.recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")] + _FakeRunManager.latest_by_thread = {} + + async def fake_init_engine_from_config(_database): + return None + + async def fake_close_engine(): + return None + + monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config) + monkeypatch.setattr(engine_module, "get_session_factory", lambda: None) + monkeypatch.setattr(engine_module, "close_engine", fake_close_engine) + monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object())) + monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object())) + monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object())) + monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store) + monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object()) + monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager) + + async with gateway_deps.langgraph_runtime(app, config): + pass + + assert len(_FakeRunManager.instances) == 1 + assert _FakeRunManager.instances[0].reconcile_calls + assert _FakeRunManager.instances[0].reconcile_calls[0]["error"] + assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}] + assert thread_store.status_updates == [("thread-1", "error", None)] + + +@pytest.mark.anyio +async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch): + """Startup recovery should not let an old orphaned run overwrite a newer terminal thread state.""" + app = FastAPI() + config = SimpleNamespace( + database=SimpleNamespace(backend="sqlite"), + run_events=SimpleNamespace(backend="memory"), + ) + thread_store = _FakeThreadStore() + _FakeRunManager.instances.clear() + _FakeRunManager.recovered_runs = [SimpleNamespace(run_id="old-running", thread_id="thread-1")] + _FakeRunManager.latest_by_thread = {"thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")]} + + async def fake_init_engine_from_config(_database): + return None + + async def fake_close_engine(): + return None + + monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config) + monkeypatch.setattr(engine_module, "get_session_factory", lambda: None) + monkeypatch.setattr(engine_module, "close_engine", fake_close_engine) + monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object())) + monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object())) + monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object())) + monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store) + monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object()) + monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager) + + async with gateway_deps.langgraph_runtime(app, config): + pass + + assert len(_FakeRunManager.instances) == 1 + assert _FakeRunManager.instances[0].list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}] + assert thread_store.status_updates == [] diff --git a/backend/tests/test_run_manager.py b/backend/tests/test_run_manager.py index e7b5f06f5..3ee877eca 100644 --- a/backend/tests/test_run_manager.py +++ b/backend/tests/test_run_manager.py @@ -1,10 +1,15 @@ """Tests for RunManager.""" +import logging import re +import sqlite3 +from typing import Any import pytest +from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError from deerflow.runtime import DisconnectMode, RunManager, RunStatus +from deerflow.runtime.runs.manager import PersistenceRetryPolicy from deerflow.runtime.runs.store.memory import MemoryRunStore ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @@ -15,6 +20,92 @@ def manager() -> RunManager: return RunManager() +class FlakyStatusRunStore(MemoryRunStore): + """Memory run store that simulates transient SQLite status-write failures.""" + + def __init__(self, *, status_failures: int) -> None: + super().__init__() + self.status_failures = status_failures + self.status_update_attempts = 0 + + async def update_status(self, run_id, status, *, error=None): + self.status_update_attempts += 1 + if self.status_failures > 0: + self.status_failures -= 1 + raise sqlite3.OperationalError("database is locked") + return await super().update_status(run_id, status, error=error) + + +class MissingRowStatusRunStore(MemoryRunStore): + """Memory run store that reports a missing row for status updates.""" + + async def update_status(self, run_id, status, *, error=None): + await super().update_status(run_id, status, error=error) + return False + + +class PermanentStatusRunStore(MemoryRunStore): + """Memory run store that simulates a permanent SQLAlchemy write failure.""" + + def __init__(self) -> None: + super().__init__() + self.status_update_attempts = 0 + + async def update_status(self, run_id, status, *, error=None): + self.status_update_attempts += 1 + raise SQLAlchemyDatabaseError( + "UPDATE runs SET status = :status WHERE run_id = :run_id", + {"status": status, "run_id": run_id}, + sqlite3.DatabaseError("no such table: runs"), + ) + + +class FailingStatusRunStore(MemoryRunStore): + """Memory run store that always fails status updates.""" + + def __init__(self) -> None: + super().__init__() + self.status_update_attempts = 0 + + async def update_status(self, run_id, status, *, error=None): + self.status_update_attempts += 1 + raise sqlite3.OperationalError("database is locked") + + +class MissingCompletionRunStore(MemoryRunStore): + """Memory run store that reports one missing row for completion updates.""" + + def __init__(self) -> None: + super().__init__() + self.completion_update_attempts = 0 + + async def update_run_completion(self, run_id, *, status, **kwargs): + self.completion_update_attempts += 1 + if self.completion_update_attempts == 1: + return False + return await super().update_run_completion(run_id, status=status, **kwargs) + + +class AlwaysMissingCompletionRunStore(MemoryRunStore): + """Memory run store that keeps reporting missing rows for completion updates.""" + + def __init__(self) -> None: + super().__init__() + self.completion_update_attempts = 0 + + async def update_run_completion(self, run_id, *, status, **kwargs): + self.completion_update_attempts += 1 + return False + + +async def _stored_statuses(store: MemoryRunStore, *run_ids: str) -> dict[str, Any]: + rows = {} + for run_id in run_ids: + row = await store.get(run_id) + rows[run_id] = row["status"] if row else None + return rows + + @pytest.mark.anyio async def test_create_and_get(manager: RunManager): """Created run should be retrievable with new fields.""" @@ -80,6 +171,155 @@ async def test_cancel_persists_interrupted_status_to_store(): assert stored["status"] == "interrupted" +@pytest.mark.anyio +async def test_status_persistence_retries_transient_sqlite_lock(): + """Transient SQLite lock errors should not leave a final status stale.""" + store = FlakyStatusRunStore(status_failures=2) + manager = RunManager(store=store) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.running) + + await manager.set_status(record.run_id, RunStatus.success) + + stored = await store.get(record.run_id) + assert stored is not None + assert stored["status"] == "success" + assert store.status_update_attempts >= 4 + + +@pytest.mark.anyio +async def test_status_persistence_recreates_missing_store_row(): + """A final status update should recreate a run row if initial persistence was lost.""" + store = MissingRowStatusRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await store.delete(record.run_id) + + await manager.set_status(record.run_id, RunStatus.error, error="boom") + + stored = await store.get(record.run_id) + assert stored is not None + assert stored["status"] == "error" + assert stored["error"] == "boom" + + +@pytest.mark.anyio +async def test_status_persistence_does_not_retry_permanent_sqlalchemy_errors(): + """Permanent SQLAlchemy failures should not be retried as SQLite pressure.""" + store = PermanentStatusRunStore() + manager = RunManager( + store=store, + persistence_retry_policy=PersistenceRetryPolicy(max_attempts=5, initial_delay=0), + ) + record = await manager.create("thread-1") + + await manager.set_status(record.run_id, RunStatus.error, error="boom") + + assert store.status_update_attempts == 1 + + +@pytest.mark.anyio +async def test_completion_persistence_recreates_missing_store_row(): + """Completion updates should recreate a missing row and persist final counters.""" + store = MissingCompletionRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.running) + await manager.set_status(record.run_id, RunStatus.success) + await store.delete(record.run_id) + + await manager.update_run_completion( + record.run_id, + status="success", + total_tokens=42, + llm_call_count=2, + last_ai_message="done", + ) + + stored = await store.get(record.run_id) + assert stored is not None + assert stored["status"] == "success" + assert stored["total_tokens"] == 42 + assert stored["llm_call_count"] == 2 + assert stored["last_ai_message"] == "done" + assert store.completion_update_attempts == 2 + + +@pytest.mark.anyio +async def test_completion_persistence_warns_when_recreated_row_still_missing(caplog): + """A second zero-row completion update after recreation should not be silent.""" + store = AlwaysMissingCompletionRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.success) + caplog.set_level(logging.WARNING, logger="deerflow.runtime.runs.manager") + + await manager.update_run_completion(record.run_id, status="success", total_tokens=42) + + assert store.completion_update_attempts == 2 + assert "affected no rows after row recreation" in caplog.text + + +@pytest.mark.anyio +async def test_reconcile_orphaned_inflight_runs_marks_stale_rows_error(): + """Startup recovery should turn persisted active rows into explicit errors.""" + store = MemoryRunStore() + await store.put("pending-run", thread_id="thread-1", status="pending", created_at="2026-01-01T00:00:00+00:00") + await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:01+00:00") + await store.put("success-run", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:02+00:00") + manager = RunManager(store=store) + + recovered = await manager.reconcile_orphaned_inflight_runs( + error="Gateway restarted before this run reached a durable final state.", + before="2026-01-01T00:00:02+00:00", + ) + + assert {record.run_id for record in recovered} == {"pending-run", "running-run"} + assert await _stored_statuses(store, "pending-run", "running-run", "success-run") == { + "pending-run": "error", + "running-run": "error", + "success-run": "success", + } + + +@pytest.mark.anyio +async def test_reconcile_orphaned_inflight_runs_skips_live_local_run(): + """Startup recovery should not mark an active row orphaned when this worker owns it.""" + store = MemoryRunStore() + manager = RunManager(store=store) + record = await manager.create("thread-1") + await manager.set_status(record.run_id, RunStatus.running) + + recovered = await manager.reconcile_orphaned_inflight_runs( + error="Gateway restarted before this run reached a durable final state.", + ) + + stored = await store.get(record.run_id) + assert recovered == [] + assert stored["status"] == "running" + + +@pytest.mark.anyio +async def test_reconcile_orphaned_inflight_runs_skips_rows_when_error_status_is_not_persisted(): + """Startup recovery must not report a row as recovered if the error update failed.""" + store = FailingStatusRunStore() + await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:00+00:00") + manager = RunManager( + store=store, + persistence_retry_policy=PersistenceRetryPolicy(max_attempts=2, initial_delay=0), + ) + + recovered = await manager.reconcile_orphaned_inflight_runs( + error="Gateway restarted before this run reached a durable final state.", + before="2026-01-01T00:00:01+00:00", + ) + + stored = await store.get("running-run") + assert recovered == [] + assert stored["status"] == "running" + assert store.status_update_attempts == 2 + + @pytest.mark.anyio async def test_cancel_not_inflight(manager: RunManager): """Cancelling a completed run should return False.""" diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index f18e51348..037201f37 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -52,6 +52,9 @@ class _CustomRunStoreWithoutProgress(RunStore): async def list_pending(self, *args, **kwargs): return [] + async def list_inflight(self, *args, **kwargs): + return [] + async def aggregate_tokens_by_thread(self, *args, **kwargs): return {} @@ -75,6 +78,19 @@ class TestRunRepository: assert row["status"] == "pending" await _cleanup() + @pytest.mark.anyio + async def test_put_is_idempotent_for_retried_writes(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", assistant_id="old-agent", status="pending") + + await repo.put("r1", thread_id="t1", assistant_id="new-agent", status="running", error="retry") + + row = await repo.get("r1") + assert row["assistant_id"] == "new-agent" + assert row["status"] == "running" + assert row["error"] == "retry" + await _cleanup() + @pytest.mark.anyio async def test_get_missing_returns_none(self, tmp_path): repo = await _make_repo(tmp_path) @@ -85,11 +101,19 @@ class TestRunRepository: async def test_update_status(self, tmp_path): repo = await _make_repo(tmp_path) await repo.put("r1", thread_id="t1") - await repo.update_status("r1", "running") + updated = await repo.update_status("r1", "running") row = await repo.get("r1") + assert updated is True assert row["status"] == "running" await _cleanup() + @pytest.mark.anyio + async def test_update_status_returns_false_for_missing_row(self, tmp_path): + repo = await _make_repo(tmp_path) + updated = await repo.update_status("missing", "error", error="lost") + assert updated is False + await _cleanup() + @pytest.mark.anyio async def test_update_status_with_error(self, tmp_path): repo = await _make_repo(tmp_path) @@ -146,11 +170,24 @@ class TestRunRepository: assert all(r["status"] == "pending" for r in pending) await _cleanup() + @pytest.mark.anyio + async def test_list_inflight_returns_pending_and_running_before_cutoff(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("pending-old", thread_id="t1", status="pending", created_at="2026-01-01T00:00:00+00:00") + await repo.put("running-old", thread_id="t1", status="running", created_at="2026-01-01T00:00:01+00:00") + await repo.put("success-old", thread_id="t1", status="success", created_at="2026-01-01T00:00:02+00:00") + await repo.put("pending-new", thread_id="t1", status="pending", created_at="2026-01-01T00:00:03+00:00") + + inflight = await repo.list_inflight(before="2026-01-01T00:00:02+00:00") + + assert [row["run_id"] for row in inflight] == ["pending-old", "running-old"] + await _cleanup() + @pytest.mark.anyio async def test_update_run_completion(self, tmp_path): repo = await _make_repo(tmp_path) await repo.put("r1", thread_id="t1", status="running") - await repo.update_run_completion( + updated = await repo.update_run_completion( "r1", status="success", total_input_tokens=100, @@ -165,6 +202,7 @@ class TestRunRepository: first_human_message="What is the meaning?", ) row = await repo.get("r1") + assert updated is True assert row["status"] == "success" assert row["total_tokens"] == 150 assert row["llm_call_count"] == 2 @@ -174,6 +212,13 @@ class TestRunRepository: assert row["first_human_message"] == "What is the meaning?" await _cleanup() + @pytest.mark.anyio + async def test_update_run_completion_returns_false_for_missing_row(self, tmp_path): + repo = await _make_repo(tmp_path) + updated = await repo.update_run_completion("missing", status="error", total_tokens=1) + assert updated is False + await _cleanup() + @pytest.mark.anyio async def test_metadata_preserved(self, tmp_path): repo = await _make_repo(tmp_path) From b103d1a7f543bfc1d72f36272a37bb98d70e7ce6 Mon Sep 17 00:00:00 2001 From: JeffJiang Date: Sat, 23 May 2026 00:10:56 +0800 Subject: [PATCH 72/86] feat(frontend): support static website demo mode (#3170) * feat(frontend): support static website demo mode * fix(frontend): render html artifact previews from blob content * chore(frontend): apply pre-commit formatting * fix(frontend): address static demo PR review comments * Update the release information of DeerFlow --------- Co-authored-by: Willem Jiang --- frontend/Makefile | 4 + frontend/next.config.js | 4 + .../leica-master-photography-article.md | 6 +- .../workspace/chats/[thread_id]/layout.tsx | 22 ++--- .../app/workspace/chats/[thread_id]/page.tsx | 1 + .../workspace/chats/[thread_id]/providers.tsx | 15 ++++ frontend/src/app/workspace/layout.tsx | 14 +-- .../artifacts/artifact-file-detail.tsx | 50 +++++++++-- frontend/src/content/en/index.mdx | 34 ++++---- .../src/content/en/posts/releases/2_0_rc.mdx | 9 ++ frontend/src/content/zh/index.mdx | 22 ++--- frontend/src/core/api/api-client.ts | 49 +++++++++++ frontend/src/core/artifacts/utils.ts | 20 +++++ frontend/src/core/auth/AuthProvider.tsx | 20 ++++- frontend/src/core/auth/server.ts | 10 +++ frontend/src/core/auth/static-user.ts | 8 ++ frontend/src/core/models/api.ts | 10 +++ frontend/src/core/static-mode.ts | 5 ++ frontend/src/core/threads/static-demo.ts | 87 +++++++++++++++++++ .../tests/unit/core/artifacts/utils.test.ts | 69 +++++++++++++++ frontend/tests/unit/core/auth/server.test.ts | 77 ++++++++++++++++ 21 files changed, 477 insertions(+), 59 deletions(-) create mode 100644 frontend/src/app/workspace/chats/[thread_id]/providers.tsx create mode 100644 frontend/src/content/en/posts/releases/2_0_rc.mdx create mode 100644 frontend/src/core/auth/static-user.ts create mode 100644 frontend/src/core/static-mode.ts create mode 100644 frontend/src/core/threads/static-demo.ts create mode 100644 frontend/tests/unit/core/artifacts/utils.test.ts create mode 100644 frontend/tests/unit/core/auth/server.test.ts diff --git a/frontend/Makefile b/frontend/Makefile index 48d23b97b..bf6c351e2 100644 --- a/frontend/Makefile +++ b/frontend/Makefile @@ -18,3 +18,7 @@ lint: format: pnpm format:write + +build-static: + NEXT_CONFIG_BUILD_OUTPUT=standalone SKIP_ENV_VALIDATION=1 NEXT_PUBLIC_STATIC_WEBSITE_ONLY=true pnpm build + @if [ -d .next/static ]; then mkdir -p .next/standalone/.next && cp -R .next/static .next/standalone/.next/static; fi diff --git a/frontend/next.config.js b/frontend/next.config.js index 5b20aad5f..7007d59fc 100644 --- a/frontend/next.config.js +++ b/frontend/next.config.js @@ -16,6 +16,10 @@ const withNextra = nextra({}); /** @type {import("next").NextConfig} */ const config = { + output: + process.env.NEXT_CONFIG_BUILD_OUTPUT === "standalone" + ? "standalone" + : undefined, i18n: { locales: ["en", "zh"], defaultLocale: "en", diff --git a/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-master-photography-article.md b/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-master-photography-article.md index 6735fb56f..75e82aec4 100644 --- a/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-master-photography-article.md +++ b/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-master-photography-article.md @@ -32,7 +32,7 @@ Even with digital Leicas, photographers often emulate film characteristics: natu ### Image 1: Parisian Decisive Moment -![Paris Decisive Moment](/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-paris-decisive-moment.jpg) +![Paris Decisive Moment](/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-paris-decisive-moment.jpg) This image captures the essence of Cartier-Bresson's philosophy. A woman in a red coat leaps over a puddle while a cyclist passes in perfect synchrony. The composition follows the rule of thirds, with the subject positioned at the intersection of grid lines. Shot with a simulated Leica M11 and 35mm Summicron lens at f/2.8, the image features shallow depth of field, natural film grain, and the warm, muted color palette characteristic of Leica photography. @@ -40,7 +40,7 @@ The "decisive moment" here isn't just about timing—it's about the alignment of ### Image 2: Tokyo Night Reflections -![Tokyo Night Scene](/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-tokyo-night.jpg) +![Tokyo Night Scene](/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-tokyo-night.jpg) Moving to Shinjuku, Tokyo, this image explores the atmospheric possibilities of Leica's legendary Noctilux lens. Simulating a Leica M10-P with a 50mm f/0.95 Noctilux wide open, the image creates extremely shallow depth of field with beautiful bokeh balls from neon signs reflected in wet pavement. @@ -48,7 +48,7 @@ A salaryman waits under glowing kanji signs, steam rising from a nearby ramen sh ### Image 3: New York City Candid -![NYC Candid Scene](/frontend/public/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-nyc-candid.jpg) +![NYC Candid Scene](/demo/threads/7f9dc56c-e49c-4671-a3d2-c492ff4dce0c/user-data/outputs/leica-nyc-candid.jpg) This Chinatown scene demonstrates the documentary power of Leica's Q2 camera with its fixed 28mm Summilux lens. The wide angle captures environmental context while maintaining intimate proximity to the subjects. A fishmonger hands a live fish to a customer while tourists photograph the scene—a moment of cultural contrast and authentic urban life. diff --git a/frontend/src/app/workspace/chats/[thread_id]/layout.tsx b/frontend/src/app/workspace/chats/[thread_id]/layout.tsx index 877103774..eeee68347 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/layout.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/layout.tsx @@ -1,19 +1,19 @@ -"use client"; +import { isStaticWebsiteOnly } from "@/core/static-mode"; +import { DEMO_THREAD_IDS } from "@/core/threads/static-demo"; -import { PromptInputProvider } from "@/components/ai-elements/prompt-input"; -import { ArtifactsProvider } from "@/components/workspace/artifacts"; -import { SubtasksProvider } from "@/core/tasks/context"; +import { ChatProviders } from "./providers"; + +export function generateStaticParams() { + if (!isStaticWebsiteOnly()) { + return []; + } + return DEMO_THREAD_IDS.map((thread_id) => ({ thread_id })); +} export default function ChatLayout({ children, }: { children: React.ReactNode; }) { - return ( - - - {children} - - - ); + return {children}; } diff --git a/frontend/src/app/workspace/chats/[thread_id]/page.tsx b/frontend/src/app/workspace/chats/[thread_id]/page.tsx index 6f865ade8..ce3912b91 100644 --- a/frontend/src/app/workspace/chats/[thread_id]/page.tsx +++ b/frontend/src/app/workspace/chats/[thread_id]/page.tsx @@ -227,6 +227,7 @@ export default function ChatPage() { isWelcomeMode && } disabled={ + isMock || env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" || isUploading } diff --git a/frontend/src/app/workspace/chats/[thread_id]/providers.tsx b/frontend/src/app/workspace/chats/[thread_id]/providers.tsx new file mode 100644 index 000000000..46d4a4cef --- /dev/null +++ b/frontend/src/app/workspace/chats/[thread_id]/providers.tsx @@ -0,0 +1,15 @@ +"use client"; + +import { PromptInputProvider } from "@/components/ai-elements/prompt-input"; +import { ArtifactsProvider } from "@/components/workspace/artifacts"; +import { SubtasksProvider } from "@/core/tasks/context"; + +export function ChatProviders({ children }: { children: React.ReactNode }) { + return ( + + + {children} + + + ); +} diff --git a/frontend/src/app/workspace/layout.tsx b/frontend/src/app/workspace/layout.tsx index c2d567339..0d214f0d3 100644 --- a/frontend/src/app/workspace/layout.tsx +++ b/frontend/src/app/workspace/layout.tsx @@ -43,12 +43,14 @@ export default async function WorkspaceLayout({ > Retry - - Logout & Reset - +
+ +
); diff --git a/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx b/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx index 17e642fc5..93130c44f 100644 --- a/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx +++ b/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx @@ -83,7 +83,7 @@ export function ArtifactFileDetail({ const isSupportPreview = useMemo(() => { return language === "html" || language === "markdown"; }, [language]); - const { content } = useArtifactContent({ + const { content, url } = useArtifactContent({ threadId, filepath: filepathFromProps, enabled: isCodeFile && !isWriteFile, @@ -254,7 +254,9 @@ export function ArtifactFileDetail({ (language === "markdown" || language === "html") && ( )} {isCodeFile && viewMode === "code" && ( @@ -277,27 +279,33 @@ export function ArtifactFileDetail({ export function ArtifactFilePreview({ content, + isWriteFile, language, + url, }: { content: string; + isWriteFile: boolean; language: string; + url?: string; }) { const [htmlPreviewUrl, setHtmlPreviewUrl] = useState(); useEffect(() => { - if (language !== "html") { + if (language !== "html" || isWriteFile) { setHtmlPreviewUrl(undefined); return; } - const blob = new Blob([content ?? ""], { type: "text/html" }); - const url = URL.createObjectURL(blob); - setHtmlPreviewUrl(url); + const blob = new Blob([htmlWithBaseHref(content ?? "", url)], { + type: "text/html", + }); + const objectUrl = URL.createObjectURL(blob); + setHtmlPreviewUrl(objectUrl); return () => { - URL.revokeObjectURL(url); + URL.revokeObjectURL(objectUrl); }; - }, [content, language]); + }, [content, isWriteFile, language, url]); if (language === "markdown") { return ( @@ -318,9 +326,35 @@ export function ArtifactFilePreview({ className="size-full" title="Artifact preview" sandbox="allow-scripts allow-forms" - src={htmlPreviewUrl} + src={isWriteFile ? undefined : htmlPreviewUrl} + srcDoc={isWriteFile ? content : undefined} /> ); } return null; } + +function htmlWithBaseHref(content: string, url?: string) { + if (!url || /`; + if (/]*>/i.exec(content)) { + return content.replace(/]*)>/i, `${baseElement}`); + } + return `${baseElement}${content}`; +} + +function htmlBaseHref(url: string) { + const baseUrl = new URL(url, window.location.href); + baseUrl.pathname = baseUrl.pathname.replace(/\/[^/]*$/, "/"); + baseUrl.search = ""; + baseUrl.hash = ""; + return baseUrl.toString(); +} + +function escapeHtmlAttribute(value: string) { + return value.replaceAll("&", "&").replaceAll('"', """); +} diff --git a/frontend/src/content/en/index.mdx b/frontend/src/content/en/index.mdx index 0dd2efe12..e289d8f3b 100644 --- a/frontend/src/content/en/index.mdx +++ b/frontend/src/content/en/index.mdx @@ -20,27 +20,27 @@ If you want to understand how DeerFlow works, start with the Introduction. If yo Start with the conceptual overview first. -- [Introduction](/docs/introduction) -- [Why DeerFlow](/docs/introduction/why-deerflow) -- [Harness vs App](/docs/introduction/harness-vs-app) +- [Introduction](./docs/introduction) +- [Why DeerFlow](./docs/introduction/why-deerflow) +- [Harness vs App](./docs/introduction/harness-vs-app) ### If you want to build with DeerFlow Start with the Harness section. This path is for teams who want to integrate DeerFlow capabilities into their own system or build a custom agent product on top of the DeerFlow runtime. -- [DeerFlow Harness](/docs/harness) -- [Quick Start](/docs/harness/quick-start) -- [Configuration](/docs/harness/configuration) -- [Customization](/docs/harness/customization) +- [DeerFlow Harness](./docs/harness) +- [Quick Start](./docs/harness/quick-start) +- [Configuration](./docs/harness/configuration) +- [Customization](./docs/harness/customization) ### If you want to deploy and use DeerFlow Start with the App section. This path is for teams who want to run DeerFlow as a complete application and understand how to configure, operate, and use it in practice. -- [DeerFlow App](/docs/app) -- [Quick Start](/docs/app/quick-start) -- [Deployment Guide](/docs/app/deployment-guide) -- [Workspace Usage](/docs/app/workspace-usage) +- [DeerFlow App](./docs/app) +- [Quick Start](./docs/app/quick-start) +- [Deployment Guide](./docs/app/deployment-guide) +- [Workspace Usage](./docs/app/workspace-usage) ## Documentation structure @@ -79,17 +79,17 @@ The App section is written for teams who want to deploy DeerFlow as a usable pro The Tutorials section is for hands-on, task-oriented learning. -- [Tutorials](/docs/tutorials) +- [Tutorials](./docs/tutorials) ### Reference The Reference section is for detailed lookup material, including configuration, runtime modes, APIs, and source-oriented mapping. -- [Reference](/docs/reference) +- [Reference](./docs/reference) ## Choose the right path -- If you are **evaluating the project**, start with [Introduction](/docs/introduction). -- If you are **building your own agent system**, start with [DeerFlow Harness](/docs/harness). -- If you are **deploying DeerFlow for users**, start with [DeerFlow App](/docs/app). -- If you want to **learn by doing**, go to [Tutorials](/docs/tutorials). +- If you are **evaluating the project**, start with [Introduction](./docs/introduction). +- If you are **building your own agent system**, start with [DeerFlow Harness](./docs/harness). +- If you are **deploying DeerFlow for users**, start with [DeerFlow App](./docs/app). +- If you want to **learn by doing**, go to [Tutorials](./docs/tutorials). diff --git a/frontend/src/content/en/posts/releases/2_0_rc.mdx b/frontend/src/content/en/posts/releases/2_0_rc.mdx new file mode 100644 index 000000000..1f5f347c4 --- /dev/null +++ b/frontend/src/content/en/posts/releases/2_0_rc.mdx @@ -0,0 +1,9 @@ +--- +title: DeerFlow 2.0 M1 +description: DeerFlow 2.0 M1 is officially in RC. Here's what you need to know. +date: 2026-05-30 +tags: + - Release +--- + +## DeerFlow 2.0 M1 Release diff --git a/frontend/src/content/zh/index.mdx b/frontend/src/content/zh/index.mdx index 5f2a18deb..912991b06 100644 --- a/frontend/src/content/zh/index.mdx +++ b/frontend/src/content/zh/index.mdx @@ -20,27 +20,27 @@ DeerFlow 是一个用于构建和运行 Agent 系统的框架。它提供了一 先从概念概述开始。 -- [简介](/docs/introduction) -- [为什么选择 DeerFlow](/docs/introduction/why-deerflow) -- [Harness 与应用的区别](/docs/introduction/harness-vs-app) +- [简介](./docs/introduction) +- [为什么选择 DeerFlow](./docs/introduction/why-deerflow) +- [Harness 与应用的区别](./docs/introduction/harness-vs-app) ### 如果你想基于 DeerFlow 进行开发 从 Harness 章节开始。这条路径适合想将 DeerFlow 功能集成到自己系统中,或基于 DeerFlow 运行时构建自定义 Agent 产品的团队。 -- [DeerFlow Harness](/docs/harness) -- [快速上手](/docs/harness/quick-start) -- [配置](/docs/harness/configuration) -- [自定义与扩展](/docs/harness/customization) +- [DeerFlow Harness](./docs/harness) +- [快速上手](./docs/harness/quick-start) +- [配置](./docs/harness/configuration) +- [自定义与扩展](./docs/harness/customization) ### 如果你想部署和使用 DeerFlow 从应用章节开始。这条路径适合想将 DeerFlow 作为完整应用运行,并了解如何配置、运维和实际使用的团队。 -- [DeerFlow 应用](/docs/application) -- [快速上手](/docs/application/quick-start) -- [部署指南](/docs/application/deployment-guide) -- [工作区使用](/docs/application/workspace-usage) +- [DeerFlow 应用](./docs/application) +- [快速上手](./docs/application/quick-start) +- [部署指南](./docs/application/deployment-guide) +- [工作区使用](./docs/application/workspace-usage) ## 文档结构 diff --git a/frontend/src/core/api/api-client.ts b/frontend/src/core/api/api-client.ts index 0b4532ca9..841c2cdfb 100644 --- a/frontend/src/core/api/api-client.ts +++ b/frontend/src/core/api/api-client.ts @@ -3,6 +3,13 @@ import { Client as LangGraphClient } from "@langchain/langgraph-sdk/client"; import { getLangGraphBaseURL } from "../config"; +import { isStaticWebsiteOnly } from "../static-mode"; +import { + loadStaticDemoThread, + loadStaticDemoThreads, + staticDemoThreadState, +} from "../threads/static-demo"; +import type { AgentThreadState } from "../threads/types"; import { isStateChangingMethod, readCsrfCookie } from "./fetcher"; import { sanitizeRunStreamOptions } from "./stream-mode"; @@ -32,6 +39,10 @@ function injectCsrfHeader(_url: URL, init: RequestInit): RequestInit { } function createCompatibleClient(isMock?: boolean): LangGraphClient { + if (isStaticWebsiteOnly() && !isMock) { + return createStaticClient(); + } + const apiUrl = getLangGraphBaseURL(isMock); console.log(`Creating API client with base URL: ${apiUrl}`); const client = new LangGraphClient({ @@ -58,6 +69,44 @@ function createCompatibleClient(isMock?: boolean): LangGraphClient { return client; } +function createStaticClient(): LangGraphClient { + const apiUrl = + typeof window === "undefined" + ? "http://localhost:3000" + : window.location.origin; + const client = new LangGraphClient({ apiUrl }); + + client.threads.search = (async (query) => { + return loadStaticDemoThreads(query); + }) as typeof client.threads.search; + + client.threads.get = (async (threadId) => { + return loadStaticDemoThread(threadId); + }) as typeof client.threads.get; + + client.threads.getState = (async (threadId) => { + return staticDemoThreadState(await loadStaticDemoThread(threadId)); + }) as typeof client.threads.getState; + + client.threads.getHistory = (async (threadId) => { + return [staticDemoThreadState(await loadStaticDemoThread(threadId))]; + }) as typeof client.threads.getHistory; + + client.threads.update = (async (threadId) => { + return loadStaticDemoThread(threadId); + }) as typeof client.threads.update; + + client.runs.list = (async () => []) as typeof client.runs.list; + client.runs.stream = async function* () { + /* empty */ + } as typeof client.runs.stream; + client.runs.joinStream = async function* () { + /* empty */ + } as typeof client.runs.joinStream; + + return client as LangGraphClient; +} + const _clients = new Map(); export function getAPIClient(isMock?: boolean): LangGraphClient { const cacheKey = isMock ? "mock" : "default"; diff --git a/frontend/src/core/artifacts/utils.ts b/frontend/src/core/artifacts/utils.ts index 402696504..e205b739a 100644 --- a/frontend/src/core/artifacts/utils.ts +++ b/frontend/src/core/artifacts/utils.ts @@ -1,4 +1,5 @@ import { getBackendBaseURL } from "../config"; +import { isStaticWebsiteOnly } from "../static-mode"; import type { AgentThread } from "../threads"; export function urlOfArtifact({ @@ -12,6 +13,9 @@ export function urlOfArtifact({ download?: boolean; isMock?: boolean; }) { + if (isStaticWebsiteOnly()) { + return staticDemoArtifactURL({ filepath, threadId, download }); + } if (isMock) { return `${getBackendBaseURL()}/mock/api/threads/${threadId}/artifacts${filepath}${download ? "?download=true" : ""}`; } @@ -23,5 +27,21 @@ export function extractArtifactsFromThread(thread: AgentThread) { } export function resolveArtifactURL(absolutePath: string, threadId: string) { + if (isStaticWebsiteOnly()) { + return staticDemoArtifactURL({ filepath: absolutePath, threadId }); + } return `${getBackendBaseURL()}/api/threads/${threadId}/artifacts${absolutePath}`; } + +function staticDemoArtifactURL({ + filepath, + threadId, + download = false, +}: { + filepath: string; + threadId: string; + download?: boolean; +}) { + const demoPath = filepath.replace(/^\/mnt\//, "/"); + return `${getBackendBaseURL()}/demo/threads/${threadId}${demoPath}${download ? "?download=true" : ""}`; +} diff --git a/frontend/src/core/auth/AuthProvider.tsx b/frontend/src/core/auth/AuthProvider.tsx index 652cc49b8..5824c5f7b 100644 --- a/frontend/src/core/auth/AuthProvider.tsx +++ b/frontend/src/core/auth/AuthProvider.tsx @@ -10,6 +10,8 @@ import React, { type ReactNode, } from "react"; +import { isStaticWebsiteOnly } from "../static-mode"; + import { type User, buildLoginUrl } from "./types"; // Re-export for consumers @@ -46,6 +48,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) { const [isLoading, setIsLoading] = useState(false); const router = useRouter(); const pathname = usePathname(); + const staticMode = isStaticWebsiteOnly(); const isAuthenticated = user !== null; @@ -54,6 +57,8 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) { * Used when initialUser might be stale (e.g., after tab was inactive) */ const refreshUser = useCallback(async () => { + if (staticMode) return; + try { setIsLoading(true); const res = await fetch("/api/v1/auth/me", { @@ -77,7 +82,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) { } finally { setIsLoading(false); } - }, [pathname, router]); + }, [staticMode, pathname, router]); /** * Logout - call FastAPI logout endpoint and clear local state @@ -87,6 +92,11 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) { // Immediately clear local state to prevent UI flicker setUser(null); + if (staticMode) { + router.push("/"); + return; + } + try { await fetch("/api/v1/auth/logout", { method: "POST", @@ -99,7 +109,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) { // Redirect to home page router.push("/"); - }, [router]); + }, [staticMode, router]); /** * Handle visibility change - refresh user when tab becomes visible again. @@ -108,6 +118,8 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) { const lastCheckRef = React.useRef(0); useEffect(() => { + if (staticMode) return; + const handleVisibilityChange = () => { if (document.visibilityState !== "visible" || user === null) return; const now = Date.now(); @@ -120,7 +132,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) { return () => { document.removeEventListener("visibilitychange", handleVisibilityChange); }; - }, [user, refreshUser]); + }, [staticMode, user, refreshUser]); const value: AuthContextType = { user, @@ -155,6 +167,8 @@ export function useRequireAuth(): AuthContextType { const pathname = usePathname(); useEffect(() => { + if (isStaticWebsiteOnly()) return; + // Only redirect if we're sure user is not authenticated (not just loading) if (!auth.isLoading && !auth.isAuthenticated) { router.push(buildLoginUrl(pathname || "/workspace")); diff --git a/frontend/src/core/auth/server.ts b/frontend/src/core/auth/server.ts index 6ca3195c4..5712f1e89 100644 --- a/frontend/src/core/auth/server.ts +++ b/frontend/src/core/auth/server.ts @@ -1,6 +1,9 @@ import { cookies } from "next/headers"; +import { isStaticWebsiteOnly } from "../static-mode"; + import { getGatewayConfig } from "./gateway-config"; +import { STATIC_WEBSITE_USER } from "./static-user"; import { type AuthResult, userSchema } from "./types"; const SSR_AUTH_TIMEOUT_MS = 5_000; @@ -10,6 +13,13 @@ const SSR_AUTH_TIMEOUT_MS = 5_000; * Returns a tagged AuthResult — callers use exhaustive switch, no try/catch. */ export async function getServerSideUser(): Promise { + if (isStaticWebsiteOnly()) { + return { + tag: "authenticated", + user: STATIC_WEBSITE_USER, + }; + } + if (process.env.DEER_FLOW_AUTH_DISABLED === "1") { return { tag: "authenticated", diff --git a/frontend/src/core/auth/static-user.ts b/frontend/src/core/auth/static-user.ts new file mode 100644 index 000000000..31615e1d4 --- /dev/null +++ b/frontend/src/core/auth/static-user.ts @@ -0,0 +1,8 @@ +import type { User } from "./types"; + +export const STATIC_WEBSITE_USER: User = { + id: "static-website-user", + email: "static@example.local", + system_role: "admin", + needs_setup: false, +}; diff --git a/frontend/src/core/models/api.ts b/frontend/src/core/models/api.ts index 46675bf6d..d924e3529 100644 --- a/frontend/src/core/models/api.ts +++ b/frontend/src/core/models/api.ts @@ -1,8 +1,18 @@ import { getBackendBaseURL } from "../config"; +import { isStaticWebsiteOnly } from "../static-mode"; import type { ModelsResponse } from "./types"; +const STATIC_MODELS_RESPONSE: ModelsResponse = { + models: [], + token_usage: { enabled: false }, +}; + export async function loadModels(): Promise { + if (isStaticWebsiteOnly()) { + return STATIC_MODELS_RESPONSE; + } + const res = await fetch(`${getBackendBaseURL()}/api/models`); const data = (await res.json()) as Partial; return { diff --git a/frontend/src/core/static-mode.ts b/frontend/src/core/static-mode.ts new file mode 100644 index 000000000..2d035f128 --- /dev/null +++ b/frontend/src/core/static-mode.ts @@ -0,0 +1,5 @@ +import { env } from "@/env"; + +export function isStaticWebsiteOnly() { + return env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"; +} diff --git a/frontend/src/core/threads/static-demo.ts b/frontend/src/core/threads/static-demo.ts new file mode 100644 index 000000000..93c8c1c53 --- /dev/null +++ b/frontend/src/core/threads/static-demo.ts @@ -0,0 +1,87 @@ +import type { ThreadState } from "@langchain/langgraph-sdk"; +import type { ThreadsClient } from "@langchain/langgraph-sdk/client"; + +import type { AgentThread, AgentThreadState } from "./types"; + +export const DEMO_THREAD_IDS = [ + "21cfea46-34bd-4aa6-9e1f-3009452fbeb9", + "3823e443-4e2b-4679-b496-a9506eae462b", + "4f3e55ee-f853-43db-bfb3-7d1a411f03cb", + "5aa47db1-d0cb-4eb9-aea5-3dac1b371c5a", + "7cfa5f8f-a2f8-47ad-acbd-da7137baf990", + "7f9dc56c-e49c-4671-a3d2-c492ff4dce0c", + "90040b36-7eba-4b97-ba89-02c3ad47a8b9", + "ad76c455-5bf9-4335-8517-fc03834ab828", + "b83fbb2a-4e36-4d82-9de0-7b2a02c2092a", + "c02bb4d5-4202-490e-ae8f-ff4864fc0d2e", + "d3e5adaf-084c-4dd5-9d29-94f1d6bccd98", + "f4125791-0128-402a-8ca9-50e0947557e4", + "fe3f7974-1bcb-4a01-a950-79673baafefd", +] as const; + +export type ThreadSearchParams = NonNullable< + Parameters[0] +>; + +export async function loadStaticDemoThreads( + params: ThreadSearchParams = {}, +): Promise { + const threads = await Promise.all( + DEMO_THREAD_IDS.map((threadId) => loadStaticDemoThread(threadId)), + ); + + const sortBy = params.sortBy ?? "updated_at"; + const sortOrder = params.sortOrder ?? "desc"; + const sortedThreads = [...threads].sort((a, b) => { + const aTimestamp = (a as unknown as Record)[sortBy]; + const bTimestamp = (b as unknown as Record)[sortBy]; + const aParsed = typeof aTimestamp === "string" ? Date.parse(aTimestamp) : 0; + const bParsed = typeof bTimestamp === "string" ? Date.parse(bTimestamp) : 0; + const aValue = Number.isNaN(aParsed) ? 0 : aParsed; + const bValue = Number.isNaN(bParsed) ? 0 : bParsed; + return sortOrder === "asc" ? aValue - bValue : bValue - aValue; + }); + + const offset = Math.max(0, Math.floor(params.offset ?? 0)); + const limit = + typeof params.limit === "number" + ? Math.max(0, Math.floor(params.limit)) + : sortedThreads.length; + return sortedThreads.slice(offset, offset + limit); +} + +export async function loadStaticDemoThread( + threadId: string, +): Promise { + const response = await globalThis.fetch( + `/demo/threads/${encodeURIComponent(threadId)}/thread.json`, + ); + if (!response.ok) { + throw new Error(`Failed to load demo thread ${threadId}`); + } + const thread = (await response.json()) as AgentThread; + return { + ...thread, + thread_id: threadId, + updated_at: thread.updated_at ?? thread.created_at, + }; +} + +export function staticDemoThreadState( + thread: AgentThread, +): ThreadState { + return { + values: thread.values, + next: [], + checkpoint: { + thread_id: thread.thread_id, + checkpoint_ns: "", + checkpoint_id: null, + checkpoint_map: null, + }, + metadata: thread.metadata ?? null, + created_at: thread.updated_at ?? thread.created_at ?? null, + parent_checkpoint: null, + tasks: [], + }; +} diff --git a/frontend/tests/unit/core/artifacts/utils.test.ts b/frontend/tests/unit/core/artifacts/utils.test.ts new file mode 100644 index 000000000..c0400b371 --- /dev/null +++ b/frontend/tests/unit/core/artifacts/utils.test.ts @@ -0,0 +1,69 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +const ENV_KEYS = [ + "NEXT_PUBLIC_BACKEND_BASE_URL", + "NEXT_PUBLIC_STATIC_WEBSITE_ONLY", +] as const; + +type EnvSnapshot = Partial< + Record<(typeof ENV_KEYS)[number], string | undefined> +>; + +function snapshotEnv(): EnvSnapshot { + const snapshot: EnvSnapshot = {}; + for (const key of ENV_KEYS) { + snapshot[key] = process.env[key]; + } + return snapshot; +} + +function setEnv(key: (typeof ENV_KEYS)[number], value: string | undefined) { + const env = process.env as Record; + if (value === undefined) { + delete env[key]; + } else { + env[key] = value; + } +} + +function restoreEnv(snapshot: EnvSnapshot) { + for (const key of ENV_KEYS) { + setEnv(key, snapshot[key]); + } +} + +async function loadFreshArtifactUtils() { + vi.resetModules(); + return await import("@/core/artifacts/utils"); +} + +describe("artifact URL helpers", () => { + let saved: EnvSnapshot; + + beforeEach(() => { + saved = snapshotEnv(); + setEnv("NEXT_PUBLIC_BACKEND_BASE_URL", undefined); + setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined); + }); + + afterEach(() => { + restoreEnv(saved); + }); + + test("maps static demo artifact paths to bundled public files", async () => { + setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", "true"); + + const { resolveArtifactURL, urlOfArtifact } = + await loadFreshArtifactUtils(); + + expect( + urlOfArtifact({ + filepath: "/mnt/user-data/outputs/index.html", + threadId: "thread-1", + }), + ).toBe("/demo/threads/thread-1/user-data/outputs/index.html"); + expect( + resolveArtifactURL("/mnt/user-data/outputs/style.css", "thread-1"), + ).toBe("/demo/threads/thread-1/user-data/outputs/style.css"); + }); +}); diff --git a/frontend/tests/unit/core/auth/server.test.ts b/frontend/tests/unit/core/auth/server.test.ts new file mode 100644 index 000000000..fea6ef830 --- /dev/null +++ b/frontend/tests/unit/core/auth/server.test.ts @@ -0,0 +1,77 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from "vitest"; + +import { STATIC_WEBSITE_USER } from "@/core/auth/static-user"; + +vi.mock("next/headers", () => ({ + cookies: vi.fn(() => { + throw new Error("cookies should not be read in static website mode"); + }), +})); + +const ENV_KEYS = [ + "DEER_FLOW_AUTH_DISABLED", + "NEXT_PUBLIC_STATIC_WEBSITE_ONLY", +] as const; + +type EnvSnapshot = Partial< + Record<(typeof ENV_KEYS)[number], string | undefined> +>; + +function snapshotEnv(): EnvSnapshot { + const snapshot: EnvSnapshot = {}; + for (const key of ENV_KEYS) { + snapshot[key] = process.env[key]; + } + return snapshot; +} + +function setEnv(key: (typeof ENV_KEYS)[number], value: string | undefined) { + const env = process.env as Record; + if (value === undefined) { + delete env[key]; + } else { + env[key] = value; + } +} + +function restoreEnv(snapshot: EnvSnapshot) { + for (const key of ENV_KEYS) { + setEnv(key, snapshot[key]); + } +} + +async function loadFreshServerAuth() { + vi.resetModules(); + return await import("@/core/auth/server"); +} + +describe("getServerSideUser", () => { + let saved: EnvSnapshot; + + beforeEach(() => { + saved = snapshotEnv(); + setEnv("DEER_FLOW_AUTH_DISABLED", undefined); + setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", undefined); + }); + + afterEach(() => { + restoreEnv(saved); + vi.unstubAllGlobals(); + }); + + test("bypasses gateway auth in static website mode", async () => { + setEnv("NEXT_PUBLIC_STATIC_WEBSITE_ONLY", "true"); + const fetchSpy = vi.fn(() => { + throw new Error("fetch should not be called in static website mode"); + }); + vi.stubGlobal("fetch", fetchSpy); + + const { getServerSideUser } = await loadFreshServerAuth(); + + await expect(getServerSideUser()).resolves.toEqual({ + tag: "authenticated", + user: STATIC_WEBSITE_USER, + }); + expect(fetchSpy).not.toHaveBeenCalled(); + }); +}); From a64a39dbc056f6c2b3f28e5feb65d2014471cebe Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Sat, 23 May 2026 09:38:25 +0200 Subject: [PATCH 73/86] config: raise default summarization trigger before v2.0-m1 (#3174) * config: update summarization configuration * docs: sync summarization trigger guidance --- config.example.yaml | 4 ++-- frontend/src/content/en/harness/middlewares.mdx | 2 +- frontend/src/content/zh/harness/middlewares.mdx | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/config.example.yaml b/config.example.yaml index 8e289fac9..4e5a1abce 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -799,9 +799,9 @@ summarization: # Summarization runs when ANY threshold is met (OR logic) # You can specify a single trigger or a list of triggers trigger: - # Trigger when token count reaches 15564 + # Trigger when token count reaches 32000 - type: tokens - value: 15564 + value: 32000 # Uncomment to also trigger when message count reaches 50 # - type: messages # value: 50 diff --git a/frontend/src/content/en/harness/middlewares.mdx b/frontend/src/content/en/harness/middlewares.mdx index 389b9881c..b185696a7 100644 --- a/frontend/src/content/en/harness/middlewares.mdx +++ b/frontend/src/content/en/harness/middlewares.mdx @@ -162,7 +162,7 @@ summarization: # Trigger conditions — summarization runs when ANY threshold is met trigger: - type: tokens # trigger when context exceeds N tokens - value: 15564 + value: 32000 # - type: messages # trigger when there are more than N messages # value: 50 # - type: fraction # trigger when context exceeds X% of model max diff --git a/frontend/src/content/zh/harness/middlewares.mdx b/frontend/src/content/zh/harness/middlewares.mdx index 9e81caa3e..0b6dc4894 100644 --- a/frontend/src/content/zh/harness/middlewares.mdx +++ b/frontend/src/content/zh/harness/middlewares.mdx @@ -154,7 +154,7 @@ summarization: # 触发条件——满足任意一个条件时运行摘要 trigger: - type: tokens # 当上下文超过 N 个 token 时触发 - value: 15564 + value: 32000 # - type: messages # 当消息数超过 N 时触发 # value: 50 # - type: fraction # 当上下文达到模型最大输入的 X% 时触发 From 604fcbb9d238d6ffff081be7a0f7b43babbc03cf Mon Sep 17 00:00:00 2001 From: AochenShen99 Date: Sat, 23 May 2026 16:56:14 +0800 Subject: [PATCH 74/86] Stabilize write artifact previews (#3172) --- frontend/.prettierignore | 2 + .../artifacts/artifact-file-detail.tsx | 239 +++++++++++--- frontend/src/core/artifacts/loader.ts | 9 + frontend/src/core/artifacts/preview.ts | 278 ++++++++++++++++ frontend/tests/e2e/artifact-preview.spec.ts | 174 ++++++++++ frontend/tests/e2e/utils/mock-api.ts | 56 +++- .../tests/unit/core/artifacts/preview.test.ts | 310 ++++++++++++++++++ 7 files changed, 1022 insertions(+), 46 deletions(-) create mode 100644 frontend/src/core/artifacts/preview.ts create mode 100644 frontend/tests/e2e/artifact-preview.spec.ts create mode 100644 frontend/tests/unit/core/artifacts/preview.test.ts diff --git a/frontend/.prettierignore b/frontend/.prettierignore index 1eebfc69d..c409ef819 100644 --- a/frontend/.prettierignore +++ b/frontend/.prettierignore @@ -1,3 +1,5 @@ pnpm-lock.yaml .omc/ src/content/**/*.mdx +playwright-report/ +test-results/ diff --git a/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx b/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx index 93130c44f..46ae18441 100644 --- a/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx +++ b/frontend/src/components/workspace/artifacts/artifact-file-detail.tsx @@ -8,7 +8,7 @@ import { SquareArrowOutUpRightIcon, XIcon, } from "lucide-react"; -import { useCallback, useEffect, useMemo, useState } from "react"; +import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { toast } from "sonner"; import { Streamdown } from "streamdown"; @@ -30,8 +30,16 @@ import { import { ToggleGroup, ToggleGroupItem } from "@/components/ui/toggle-group"; import { CodeEditor } from "@/components/workspace/code-editor"; import { useArtifactContent } from "@/core/artifacts/hooks"; +import { + appendHtmlPreviewBaseHref, + appendHtmlPreviewScrollRestoration, + createHtmlPreviewScrollKey, + getArtifactViewState, + HTML_PREVIEW_SCROLL_MESSAGE_SOURCE, +} from "@/core/artifacts/preview"; import { urlOfArtifact } from "@/core/artifacts/utils"; import { useI18n } from "@/core/i18n/hooks"; +import { findToolCallResult } from "@/core/messages/utils"; import { installSkill } from "@/core/skills/api"; import { streamdownPlugins } from "@/core/streamdown"; import { checkCodeFile, getFileName } from "@/core/utils/files"; @@ -44,6 +52,8 @@ import { Tooltip } from "../tooltip"; import { useArtifacts } from "./context"; +const WRITE_FILE_PREVIEW_REFRESH_INTERVAL_MS = 3000; + export function ArtifactFileDetail({ className, filepath: filepathFromProps, @@ -55,6 +65,7 @@ export function ArtifactFileDetail({ }) { const { t } = useI18n(); const { artifacts, setOpen, select } = useArtifacts(); + const { thread, isMock } = useThread(); const isWriteFile = useMemo(() => { return filepathFromProps.startsWith("write-file:"); }, [filepathFromProps]); @@ -83,6 +94,22 @@ export function ArtifactFileDetail({ const isSupportPreview = useMemo(() => { return language === "html" || language === "markdown"; }, [language]); + const toolResult = (() => { + if (!isWriteFile) { + return undefined; + } + const url = new URL(filepathFromProps); + const toolCallId = url.searchParams.get("tool_call_id"); + if (!toolCallId) { + return undefined; + } + return findToolCallResult(toolCallId, thread.messages); + })(); + const artifactViewState = getArtifactViewState({ + filepath: filepathFromProps, + isSupportPreview, + toolResult, + }); const { content, url } = useArtifactContent({ threadId, filepath: filepathFromProps, @@ -90,17 +117,20 @@ export function ArtifactFileDetail({ }); const displayContent = content ?? ""; + const isWritingFile = isWriteFile && toolResult === undefined; + const visibleContent = useThrottledValue( + displayContent, + isWritingFile ? WRITE_FILE_PREVIEW_REFRESH_INTERVAL_MS : 0, + filepathFromProps, + ); - const [viewMode, setViewMode] = useState<"code" | "preview">("code"); + const [viewMode, setViewMode] = useState<"code" | "preview">( + artifactViewState.initialViewMode, + ); const [isInstalling, setIsInstalling] = useState(false); - const { isMock } = useThread(); useEffect(() => { - if (isSupportPreview) { - setViewMode("preview"); - } else { - setViewMode("code"); - } - }, [isSupportPreview]); + setViewMode(artifactViewState.initialViewMode); + }, [artifactViewState.initialViewMode]); const handleInstallSkill = useCallback(async () => { if (isInstalling) return; @@ -149,7 +179,7 @@ export function ArtifactFileDetail({
- {isSupportPreview && ( + {artifactViewState.canPreview && ( { try { - await navigator.clipboard.writeText(displayContent ?? ""); + await navigator.clipboard.writeText(visibleContent ?? ""); toast.success(t.clipboard.copiedToClipboard); } catch (error) { toast.error("Failed to copy to clipboard"); @@ -249,20 +279,20 @@ export function ArtifactFileDetail({
- {isSupportPreview && + {artifactViewState.canPreview && viewMode === "preview" && (language === "markdown" || language === "html") && ( )} {isCodeFile && viewMode === "code" && ( )} @@ -279,25 +309,78 @@ export function ArtifactFileDetail({ export function ArtifactFilePreview({ content, - isWriteFile, language, + scrollKey, url, }: { content: string; - isWriteFile: boolean; language: string; + scrollKey: string; url?: string; }) { + const iframeRef = useRef(null); + const scrollPositionRef = useRef({ x: 0, y: 0 }); + const scrollMessageKey = useMemo( + () => createHtmlPreviewScrollKey(scrollKey), + [scrollKey], + ); const [htmlPreviewUrl, setHtmlPreviewUrl] = useState(); useEffect(() => { - if (language !== "html" || isWriteFile) { + scrollPositionRef.current = { x: 0, y: 0 }; + }, [scrollMessageKey]); + + useEffect(() => { + if (language !== "html") { + return; + } + + const handleMessage = (event: MessageEvent) => { + if (event.source !== iframeRef.current?.contentWindow) { + return; + } + if (!isArtifactScrollMessage(event.data, scrollMessageKey)) { + return; + } + + if (event.data.type === "save") { + const x = scrollCoordinate(event.data.x); + const y = scrollCoordinate(event.data.y); + if (x !== undefined && y !== undefined) { + scrollPositionRef.current = { x, y }; + } + return; + } + + iframeRef.current?.contentWindow?.postMessage( + { + source: HTML_PREVIEW_SCROLL_MESSAGE_SOURCE, + key: scrollMessageKey, + type: "restore", + ...scrollPositionRef.current, + }, + "*", + ); + }; + + window.addEventListener("message", handleMessage); + return () => { + window.removeEventListener("message", handleMessage); + }; + }, [language, scrollMessageKey]); + + useEffect(() => { + if (language !== "html") { setHtmlPreviewUrl(undefined); return; } - const blob = new Blob([htmlWithBaseHref(content ?? "", url)], { - type: "text/html", + const previewContent = appendHtmlPreviewScrollRestoration( + appendHtmlPreviewBaseHref(content ?? "", url), + scrollKey, + ); + const blob = new Blob([previewContent], { + type: "text/html;charset=utf-8", }); const objectUrl = URL.createObjectURL(blob); setHtmlPreviewUrl(objectUrl); @@ -305,7 +388,7 @@ export function ArtifactFilePreview({ return () => { URL.revokeObjectURL(objectUrl); }; - }, [content, isWriteFile, language, url]); + }, [content, language, scrollKey, url]); if (language === "markdown") { return ( @@ -323,38 +406,110 @@ export function ArtifactFilePreview({ if (language === "html") { return (