From 8939ccaed2f03ba831b16c83605d09b80f36e632 Mon Sep 17 00:00:00 2001 From: KiteEater <145987840+Kiteeater@users.noreply.github.com> Date: Fri, 1 May 2026 20:19:30 +0800 Subject: [PATCH 01/11] fix(uploads): enforce streaming upload limits in gateway (#2589) * fix: enforce gateway upload limits * fix: acquire sandbox before upload writes * Fix upload limit config wiring * Sanitize upload size error filenames * test: call upload routes unwrapped * fix: guard upload limits endpoint --------- Co-authored-by: Willem Jiang --- backend/app/gateway/routers/uploads.py | 124 +++++++++++- backend/docs/FILE_UPLOAD.md | 22 ++- backend/tests/test_uploads_router.py | 256 ++++++++++++++++++++++++- config.example.yaml | 5 + 4 files changed, 393 insertions(+), 14 deletions(-) diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py index 0ecc2266a..604a6e154 100644 --- a/backend/app/gateway/routers/uploads.py +++ b/backend/app/gateway/routers/uploads.py @@ -30,6 +30,11 @@ logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"]) +UPLOAD_CHUNK_SIZE = 8192 +DEFAULT_MAX_FILES = 10 +DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024 +DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024 + class UploadResponse(BaseModel): """Response model for file upload.""" @@ -39,6 +44,14 @@ class UploadResponse(BaseModel): message: str +class UploadLimits(BaseModel): + """Application-level upload limits exposed to clients.""" + + max_files: int + max_file_size: int + max_total_size: int + + def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None: """Ensure uploaded files remain writable when mounted into non-local sandboxes. @@ -69,6 +82,62 @@ def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) return getattr(uploads_cfg, key, default) +def _get_upload_limit(app_config: AppConfig, key: str, default: int, *, legacy_key: str | None = None) -> int: + try: + value = _get_uploads_config_value(app_config, key, None) + if value is None and legacy_key is not None: + value = _get_uploads_config_value(app_config, legacy_key, None) + if value is None: + value = default + limit = int(value) + if limit <= 0: + raise ValueError + return limit + except Exception: + logger.warning("Invalid uploads.%s value; falling back to %d", key, default) + return default + + +def _get_upload_limits(app_config: AppConfig) -> UploadLimits: + return UploadLimits( + max_files=_get_upload_limit(app_config, "max_files", DEFAULT_MAX_FILES, legacy_key="max_file_count"), + max_file_size=_get_upload_limit(app_config, "max_file_size", DEFAULT_MAX_FILE_SIZE, legacy_key="max_single_file_size"), + max_total_size=_get_upload_limit(app_config, "max_total_size", DEFAULT_MAX_TOTAL_SIZE), + ) + + +def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None: + for path in reversed(paths): + try: + os.unlink(path) + except FileNotFoundError: + pass + except Exception: + logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True) + + +async def _write_upload_file_streaming( + file: UploadFile, + file_path: os.PathLike[str] | str, + *, + display_filename: str, + max_single_file_size: int, + max_total_size: int, + total_size: int, +) -> tuple[int, int]: + file_size = 0 + with open(file_path, "wb") as output: + while chunk := await file.read(UPLOAD_CHUNK_SIZE): + file_size += len(chunk) + total_size += len(chunk) + if file_size > max_single_file_size: + raise HTTPException(status_code=413, detail=f"File too 
large: {display_filename}") + if total_size > max_total_size: + raise HTTPException(status_code=413, detail="Total upload size too large") + output.write(chunk) + return file_size, total_size + + def _auto_convert_documents_enabled(app_config: AppConfig) -> bool: """Return whether automatic host-side document conversion is enabled. @@ -96,12 +165,19 @@ async def upload_files( if not files: raise HTTPException(status_code=400, detail="No files provided") + limits = _get_upload_limits(config) + if len(files) > limits.max_files: + raise HTTPException(status_code=413, detail=f"Too many files: maximum is {limits.max_files}") + try: uploads_dir = ensure_uploads_dir(thread_id) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) uploaded_files = [] + written_paths = [] + sandbox_sync_targets = [] + total_size = 0 sandbox_provider = get_sandbox_provider() sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider) @@ -109,6 +185,8 @@ async def upload_files( if sync_to_sandbox: sandbox_id = sandbox_provider.acquire(thread_id) sandbox = sandbox_provider.get(sandbox_id) + if sandbox is None: + raise HTTPException(status_code=500, detail="Failed to acquire sandbox") auto_convert_documents = _auto_convert_documents_enabled(config) for file in files: @@ -122,35 +200,41 @@ async def upload_files( continue try: - content = await file.read() file_path = uploads_dir / safe_filename - file_path.write_bytes(content) + written_paths.append(file_path) + file_size, total_size = await _write_upload_file_streaming( + file, + file_path, + display_filename=safe_filename, + max_single_file_size=limits.max_file_size, + max_total_size=limits.max_total_size, + total_size=total_size, + ) virtual_path = upload_virtual_path(safe_filename) - if sync_to_sandbox and sandbox is not None: - _make_file_sandbox_writable(file_path) - sandbox.update_file(virtual_path, content) + if sync_to_sandbox: + sandbox_sync_targets.append((file_path, virtual_path)) file_info = { "filename": safe_filename, - "size": str(len(content)), + "size": str(file_size), "path": str(sandbox_uploads / safe_filename), "virtual_path": virtual_path, "artifact_url": upload_artifact_url(thread_id, safe_filename), } - logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}") + logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}") file_ext = file_path.suffix.lower() if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS: md_path = await convert_file_to_markdown(file_path) if md_path: + written_paths.append(md_path) md_virtual_path = upload_virtual_path(md_path.name) - if sync_to_sandbox and sandbox is not None: - _make_file_sandbox_writable(md_path) - sandbox.update_file(md_virtual_path, md_path.read_bytes()) + if sync_to_sandbox: + sandbox_sync_targets.append((md_path, md_virtual_path)) file_info["markdown_file"] = md_path.name file_info["markdown_path"] = str(sandbox_uploads / md_path.name) @@ -159,10 +243,19 @@ async def upload_files( uploaded_files.append(file_info) + except HTTPException as e: + _cleanup_uploaded_paths(written_paths) + raise e except Exception as e: logger.error(f"Failed to upload {file.filename}: {e}") + _cleanup_uploaded_paths(written_paths) raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}") + if sync_to_sandbox: + for file_path, virtual_path in sandbox_sync_targets: + 
_make_file_sandbox_writable(file_path) + sandbox.update_file(virtual_path, file_path.read_bytes()) + return UploadResponse( success=True, files=uploaded_files, @@ -170,6 +263,17 @@ async def upload_files( ) +@router.get("/limits", response_model=UploadLimits) +@require_permission("threads", "read", owner_check=True) +async def get_upload_limits( + thread_id: str, + request: Request, + config: AppConfig = Depends(get_config), +) -> UploadLimits: + """Return upload limits used by the gateway for this thread.""" + return _get_upload_limits(config) + + @router.get("/list", response_model=dict) @require_permission("threads", "read", owner_check=True) async def list_uploaded_files(thread_id: str, request: Request) -> dict: diff --git a/backend/docs/FILE_UPLOAD.md b/backend/docs/FILE_UPLOAD.md index bf6962e4f..2b15b27e7 100644 --- a/backend/docs/FILE_UPLOAD.md +++ b/backend/docs/FILE_UPLOAD.md @@ -22,6 +22,8 @@ POST /api/threads/{thread_id}/uploads **请求体:** `multipart/form-data` - `files`: 一个或多个文件 +网关会在应用层限制上传规模,默认最多 10 个文件、单文件 50 MiB、单次请求总计 100 MiB。可通过 `config.yaml` 的 `uploads.max_files`、`uploads.max_file_size`、`uploads.max_total_size` 调整;前端会读取同一组限制并在选择文件时提示,超过限制时后端返回 `413 Payload Too Large`。 + **响应:** ```json { @@ -48,7 +50,23 @@ POST /api/threads/{thread_id}/uploads - `virtual_path`: Agent 在沙箱中使用的虚拟路径 - `artifact_url`: 前端通过 HTTP 访问文件的 URL -### 2. 列出已上传文件 +### 2. 查询上传限制 +``` +GET /api/threads/{thread_id}/uploads/limits +``` + +返回网关当前生效的上传限制,供前端在用户选择文件前提示和拦截。 + +**响应:** +```json +{ + "max_files": 10, + "max_file_size": 52428800, + "max_total_size": 104857600 +} +``` + +### 3. 列出已上传文件 ``` GET /api/threads/{thread_id}/uploads/list ``` @@ -71,7 +89,7 @@ GET /api/threads/{thread_id}/uploads/list } ``` -### 3. 删除文件 +### 4. 删除文件 ``` DELETE /api/threads/{thread_id}/uploads/{filename} ``` diff --git a/backend/tests/test_uploads_router.py b/backend/tests/test_uploads_router.py index 7f9b442d0..a2538ec40 100644 --- a/backend/tests/test_uploads_router.py +++ b/backend/tests/test_uploads_router.py @@ -5,12 +5,35 @@ from pathlib import Path from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch -from _router_auth_helpers import call_unwrapped -from fastapi import UploadFile +import pytest +from _router_auth_helpers import call_unwrapped, make_authed_test_app +from fastapi import HTTPException, UploadFile +from fastapi.testclient import TestClient from app.gateway.routers import uploads +class ChunkedUpload: + def __init__(self, filename: str, chunks: list[bytes]): + self.filename = filename + self._chunks = list(chunks) + self.read_calls: list[int | None] = [] + + async def read(self, size: int | None = None) -> bytes: + self.read_calls.append(size) + if size is None: + raise AssertionError("upload must be read with an explicit chunk size") + if not self._chunks: + return b"" + return self._chunks.pop(0) + + +def _mounted_provider() -> MagicMock: + provider = MagicMock() + provider.uses_thread_data_mounts = True + return provider + + def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_path): thread_uploads_dir = tmp_path / "uploads" thread_uploads_dir.mkdir(parents=True) @@ -178,6 +201,173 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path): make_writable.assert_not_called() +def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = MagicMock() + provider.uses_thread_data_mounts = False + sandbox = 
MagicMock() + provider.get.return_value = sandbox + + def acquire_before_writes(thread_id: str) -> str: + assert list(thread_uploads_dir.iterdir()) == [] + return "aio-1" + + provider.acquire.side_effect = acquire_before_writes + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads")) + result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace())) + + assert result.success is True + provider.acquire.assert_called_once_with("thread-aio") + sandbox.update_file.assert_called_once_with("/mnt/user-data/uploads/notes.txt", b"hello uploads") + + +def test_upload_files_fails_before_writing_when_non_local_sandbox_unavailable(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = MagicMock() + provider.uses_thread_data_mounts = False + provider.acquire.side_effect = RuntimeError("sandbox unavailable") + file = ChunkedUpload("notes.txt", [b"hello uploads"]) + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + with pytest.raises(RuntimeError, match="sandbox unavailable"): + asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace())) + + assert list(thread_uploads_dir.iterdir()) == [] + assert file.read_calls == [] + provider.get.assert_not_called() + + +def test_upload_files_rejects_too_many_files_before_writing(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()), + patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=1, max_file_size=10, max_total_size=20)), + ): + files = [ + ChunkedUpload("one.txt", [b"one"]), + ChunkedUpload("two.txt", [b"two"]), + ] + with pytest.raises(HTTPException) as exc_info: + asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace())) + + assert exc_info.value.status_code == 413 + assert list(thread_uploads_dir.iterdir()) == [] + assert files[0].read_calls == [] + assert files[1].read_calls == [] + + +def test_upload_files_rejects_oversized_single_file_and_removes_partial_file(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = _mounted_provider() + file = ChunkedUpload("big.txt", [b"123456"]) + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=5, max_total_size=20)), + ): + with pytest.raises(HTTPException) as exc_info: + asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace())) + + assert exc_info.value.status_code == 413 + assert not (thread_uploads_dir / "big.txt").exists() + assert file.read_calls == [8192] + provider.acquire.assert_not_called() + + +def 
test_upload_files_rejects_total_size_over_limit_and_cleans_request_files(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()), + patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)), + ): + files = [ + ChunkedUpload("first.txt", [b"123"]), + ChunkedUpload("second.txt", [b"456"]), + ] + with pytest.raises(HTTPException) as exc_info: + asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace())) + + assert exc_info.value.status_code == 413 + assert not (thread_uploads_dir / "first.txt").exists() + assert not (thread_uploads_dir / "second.txt").exists() + + +def test_upload_files_does_not_sync_non_local_sandbox_when_total_size_exceeds_limit(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = MagicMock() + provider.uses_thread_data_mounts = False + provider.acquire.return_value = "aio-1" + sandbox = MagicMock() + provider.get.return_value = sandbox + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)), + ): + files = [ + ChunkedUpload("first.txt", [b"123"]), + ChunkedUpload("second.txt", [b"456"]), + ] + with pytest.raises(HTTPException) as exc_info: + asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=files, config=SimpleNamespace())) + + assert exc_info.value.status_code == 413 + provider.acquire.assert_called_once_with("thread-aio") + provider.get.assert_called_once_with("aio-1") + sandbox.update_file.assert_not_called() + + +def test_upload_files_does_not_sync_non_local_sandbox_when_conversion_fails(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + provider = MagicMock() + provider.uses_thread_data_mounts = False + provider.acquire.return_value = "aio-1" + sandbox = MagicMock() + provider.get.return_value = sandbox + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + patch.object(uploads, "_auto_convert_documents_enabled", return_value=True), + patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=RuntimeError("conversion failed"))), + ): + file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes")) + with pytest.raises(HTTPException) as exc_info: + asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace())) + + assert exc_info.value.status_code == 500 + provider.acquire.assert_called_once_with("thread-aio") + provider.get.assert_called_once_with("aio-1") + sandbox.update_file.assert_not_called() + assert not (thread_uploads_dir / "report.pdf").exists() + + def test_make_file_sandbox_writable_adds_write_bits_for_regular_files(tmp_path): file_path = tmp_path / "report.pdf" file_path.write_bytes(b"pdf-bytes") @@ -286,3 +476,65 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values assert 
uploads._auto_convert_documents_enabled(true_cfg) is True assert uploads._auto_convert_documents_enabled(string_true_cfg) is True assert uploads._auto_convert_documents_enabled(string_false_cfg) is False + + +def test_upload_limits_endpoint_reads_uploads_config(): + cfg = MagicMock() + cfg.uploads = { + "max_files": 15, + "max_file_size": "1048576", + "max_total_size": 2097152, + } + + result = asyncio.run(call_unwrapped(uploads.get_upload_limits, "thread-local", request=MagicMock(), config=cfg)) + + assert result.max_files == 15 + assert result.max_file_size == 1048576 + assert result.max_total_size == 2097152 + + +def test_upload_limits_endpoint_requires_thread_access(): + cfg = MagicMock() + cfg.uploads = {} + app = make_authed_test_app(owner_check_passes=False) + app.state.config = cfg + app.include_router(uploads.router) + + with TestClient(app) as client: + response = client.get("/api/threads/thread-local/uploads/limits") + + assert response.status_code == 404 + + +def test_upload_limits_accept_legacy_config_keys(): + cfg = MagicMock() + cfg.uploads = { + "max_file_count": 7, + "max_single_file_size": 123, + "max_total_size": 456, + } + + limits = uploads._get_upload_limits(cfg) + + assert limits == uploads.UploadLimits(max_files=7, max_file_size=123, max_total_size=456) + + +def test_upload_files_uses_configured_file_count_limit(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + + cfg = MagicMock() + cfg.uploads = {"max_files": 1} + + with ( + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()), + ): + files = [ + ChunkedUpload("one.txt", [b"one"]), + ChunkedUpload("two.txt", [b"two"]), + ] + with pytest.raises(HTTPException) as exc_info: + asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=cfg)) + + assert exc_info.value.status_code == 413 diff --git a/config.example.yaml b/config.example.yaml index a14ca5886..04ccd0b12 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -501,6 +501,11 @@ tool_search: # Option 1: Local Sandbox (Default) # Executes commands directly on the host machine uploads: + # Application-level upload limits enforced by the gateway and exposed to the + # frontend before file selection. + max_files: 10 + max_file_size: 52428800 # 50 MiB + max_total_size: 104857600 # 100 MiB # Automatic Office/PDF conversion runs on the backend host before sandbox # isolation applies. Keep this disabled unless uploads come from a fully # trusted source and you intentionally accept host-side parser risk. From c09c33454458f2d6b7dc1c1352a440ba49746072 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 1 May 2026 16:19:50 +0200 Subject: [PATCH 02/11] fix(harness): resolve runtime paths from project root (#2642) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(harness): resolve runtime paths from project root * docs(config): update * fix(config): address runtime path review feedback * test(config): fix skills path e2e root * test(config): cover legacy config fallback when project root lacks config files Verifies that when DEER_FLOW_PROJECT_ROOT is unset and cwd has no config.yaml/extensions_config.json, AppConfig and ExtensionsConfig fall back to the legacy backend/repo-root candidates — the backward-compat path requested in PR #2642 review. 
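A minimal sketch of the resolution rules this commit introduces, using the helpers added in backend/packages/harness/deerflow/config/runtime_paths.py; the temporary directory stands in for a real project checkout and is illustrative only:

```python
# Hedged sketch: resolution order for the project root and runtime home.
# project_root() and runtime_home() are the helpers added by this commit.
import os
import tempfile
from pathlib import Path

from deerflow.config.runtime_paths import project_root, runtime_home

with tempfile.TemporaryDirectory() as root:
    os.environ["DEER_FLOW_PROJECT_ROOT"] = root  # must exist and be a directory
    os.environ.pop("DEER_FLOW_HOME", None)       # unset -> derived from the root
    assert project_root() == Path(root).resolve()
    assert runtime_home() == Path(root).resolve() / ".deer-flow"
```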
--------- Co-authored-by: Willem Jiang --- README.md | 2 +- README_zh.md | 2 +- backend/docs/CONFIGURATION.md | 16 +- backend/docs/SETUP.md | 22 ++- .../harness/deerflow/config/app_config.py | 16 +- .../deerflow/config/extensions_config.py | 13 +- .../packages/harness/deerflow/config/paths.py | 9 +- .../harness/deerflow/config/runtime_paths.py | 41 +++++ .../harness/deerflow/config/skills_config.py | 22 +-- .../skills/storage/local_skill_storage.py | 7 +- backend/tests/test_client_e2e.py | 7 + backend/tests/test_runtime_paths.py | 145 ++++++++++++++++++ backend/tests/test_skills_loader.py | 21 ++- config.example.yaml | 10 +- docker/docker-compose-dev.yaml | 1 + docker/docker-compose.yaml | 5 +- 16 files changed, 284 insertions(+), 55 deletions(-) create mode 100644 backend/packages/harness/deerflow/config/runtime_paths.py create mode 100644 backend/tests/test_runtime_paths.py diff --git a/README.md b/README.md index c67fdc005..0fc8f173e 100644 --- a/README.md +++ b/README.md @@ -251,7 +251,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide. If you prefer running services locally: -Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root (can be overridden via `DEER_FLOW_CONFIG_PATH`). Run `make doctor` to verify your setup before starting. +Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root. Set `DEER_FLOW_PROJECT_ROOT` to define that root explicitly, or `DEER_FLOW_CONFIG_PATH` to point at a specific config file. Runtime state defaults to `.deer-flow` under the project root and can be moved with `DEER_FLOW_HOME`; skills default to `skills/` under the project root and can be moved with `DEER_FLOW_SKILLS_PATH`. Run `make doctor` to verify your setup before starting. On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`. 1. **Check prerequisites**: diff --git a/README_zh.md b/README_zh.md index 6e4a618c7..d5317082e 100644 --- a/README_zh.md +++ b/README_zh.md @@ -194,7 +194,7 @@ make down # 停止并移除容器 如果你更希望直接在本地启动各个服务: -前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`,也可以通过 `DEER_FLOW_CONFIG_PATH` 覆盖。 +前提:先完成上面的“配置”步骤(`make config` 和模型 API key 配置)。`make dev` 需要有效配置文件,默认读取项目根目录下的 `config.yaml`。可以用 `DEER_FLOW_PROJECT_ROOT` 显式指定项目根目录,也可以用 `DEER_FLOW_CONFIG_PATH` 指向某个具体配置文件。运行期状态默认写到项目根目录下的 `.deer-flow`,可用 `DEER_FLOW_HOME` 覆盖;skills 默认读取项目根目录下的 `skills/`,可用 `DEER_FLOW_SKILLS_PATH` 覆盖。 在 Windows 上,请使用 Git Bash 运行本地开发流程。基于 bash 的服务脚本不支持直接在原生 `cmd.exe` 或 PowerShell 中执行,且 WSL 也不保证可用,因为部分脚本依赖 Git for Windows 的 `cygpath` 等工具。 1. 
**检查依赖环境**: diff --git a/backend/docs/CONFIGURATION.md b/backend/docs/CONFIGURATION.md index f87fdd236..26137951f 100644 --- a/backend/docs/CONFIGURATION.md +++ b/backend/docs/CONFIGURATION.md @@ -321,12 +321,16 @@ models: - `DEEPSEEK_API_KEY` - DeepSeek API key - `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint) - `TAVILY_API_KEY` - Tavily search API key +- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths - `DEER_FLOW_CONFIG_PATH` - Custom config file path +- `DEER_FLOW_EXTENSIONS_CONFIG_PATH` - Custom extensions config file path +- `DEER_FLOW_HOME` - Runtime state directory (defaults to `.deer-flow` under the project root) +- `DEER_FLOW_SKILLS_PATH` - Skills directory when `skills.path` is omitted - `GATEWAY_ENABLE_DOCS` - Set to `false` to disable Swagger UI (`/docs`), ReDoc (`/redoc`), and OpenAPI schema (`/openapi.json`) endpoints (default: `true`) ## Configuration Location -The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`), not in the backend directory. +The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`). Set `DEER_FLOW_PROJECT_ROOT` when the process may start from another working directory, or set `DEER_FLOW_CONFIG_PATH` to point at a specific file. ## Configuration Priority @@ -334,12 +338,12 @@ DeerFlow searches for configuration in this order: 1. Path specified in code via `config_path` argument 2. Path from `DEER_FLOW_CONFIG_PATH` environment variable -3. `config.yaml` in current working directory (typically `backend/` when running) -4. `config.yaml` in parent directory (project root: `deer-flow/`) +3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or under the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset +4. Legacy backend/repository-root locations for monorepo compatibility ## Best Practices -1. **Place `config.yaml` in project root** - Not in `backend/` directory +1. **Place `config.yaml` in project root** - Set `DEER_FLOW_PROJECT_ROOT` if the runtime starts elsewhere 2. **Never commit `config.yaml`** - It's already in `.gitignore` 3. **Use environment variables for secrets** - Don't hardcode API keys 4. 
**Keep `config.example.yaml` updated** - Document all new options @@ -350,7 +354,7 @@ DeerFlow searches for configuration in this order: ### "Config file not found" - Ensure `config.yaml` exists in the **project root** directory (`deer-flow/config.yaml`) -- The backend searches parent directory by default, so root location is preferred +- If the runtime starts outside the project root, set `DEER_FLOW_PROJECT_ROOT` - Alternatively, set `DEER_FLOW_CONFIG_PATH` environment variable to custom location ### "Invalid API key" @@ -360,7 +364,7 @@ DeerFlow searches for configuration in this order: ### "Skills not loading" - Check that `deer-flow/skills/` directory exists - Verify skills have valid `SKILL.md` files -- Check `skills.path` configuration if using custom path +- Check `skills.path` or `DEER_FLOW_SKILLS_PATH` if using a custom path ### "Docker sandbox fails to start" - Ensure Docker is running diff --git a/backend/docs/SETUP.md b/backend/docs/SETUP.md index 50885eb3f..aff0e287f 100644 --- a/backend/docs/SETUP.md +++ b/backend/docs/SETUP.md @@ -23,6 +23,9 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r # Option A: Set environment variables (recommended) export OPENAI_API_KEY="your-key-here" + # Optional: pin the project root when running from another directory + export DEER_FLOW_PROJECT_ROOT="/path/to/deer-flow" + # Option B: Edit config.yaml directly vim config.yaml # or your preferred editor ``` @@ -35,17 +38,20 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r ## Important Notes -- **Location**: `config.yaml` should be in `deer-flow/` (project root), not `deer-flow/backend/` +- **Location**: `config.yaml` should be in `deer-flow/` (project root) - **Git**: `config.yaml` is automatically ignored by git (contains secrets) -- **Priority**: If both `backend/config.yaml` and `../config.yaml` exist, backend version takes precedence +- **Runtime root**: Set `DEER_FLOW_PROJECT_ROOT` if DeerFlow may start from outside the project root +- **Runtime data**: State defaults to `.deer-flow` under the project root; set `DEER_FLOW_HOME` to move it +- **Skills**: Skills default to `skills/` under the project root; set `DEER_FLOW_SKILLS_PATH` or `skills.path` to move them ## Configuration File Locations The backend searches for `config.yaml` in this order: -1. `DEER_FLOW_CONFIG_PATH` environment variable (if set) -2. `backend/config.yaml` (current directory when running from backend/) -3. `deer-flow/config.yaml` (parent directory - **recommended location**) +1. Explicit `config_path` argument from code +2. `DEER_FLOW_CONFIG_PATH` environment variable (if set) +3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset +4. Legacy backend/repository-root locations for monorepo compatibility **Recommended**: Place `config.yaml` in project root (`deer-flow/config.yaml`). @@ -77,8 +83,8 @@ python -c "from deerflow.config.app_config import AppConfig; print(AppConfig.res If it can't find the config: 1. Ensure you've copied `config.example.yaml` to `config.yaml` -2. Verify you're in the correct directory -3. Check the file exists: `ls -la ../config.yaml` +2. Verify you're in the project root, or set `DEER_FLOW_PROJECT_ROOT` +3. 
Check the file exists: `ls -la config.yaml` ### Permission denied @@ -89,4 +95,4 @@ chmod 600 ../config.yaml # Protect sensitive configuration ## See Also - [Configuration Guide](CONFIGURATION.md) - Detailed configuration options -- [Architecture Overview](../CLAUDE.md) - System architecture \ No newline at end of file +- [Architecture Overview](../CLAUDE.md) - System architecture diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index b31d396a5..a41108372 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -17,6 +17,7 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_ from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict from deerflow.config.model_config import ModelConfig from deerflow.config.run_events_config import RunEventsConfig +from deerflow.config.runtime_paths import existing_project_file from deerflow.config.sandbox_config import SandboxConfig from deerflow.config.skill_evolution_config import SkillEvolutionConfig from deerflow.config.skills_config import SkillsConfig @@ -46,8 +47,8 @@ class CircuitBreakerConfig(BaseModel): recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit") -def _default_config_candidates() -> tuple[Path, ...]: - """Return deterministic config.yaml locations without relying on cwd.""" +def _legacy_config_candidates() -> tuple[Path, ...]: + """Return source-tree config.yaml locations for monorepo compatibility.""" backend_dir = Path(__file__).resolve().parents[4] repo_root = backend_dir.parent return (backend_dir / "config.yaml", repo_root / "config.yaml") @@ -110,7 +111,8 @@ class AppConfig(BaseModel): Priority: 1. If provided `config_path` argument, use it. 2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it. - 3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`. + 3. Otherwise, search the caller project root. + 4. Finally, search legacy backend/repository-root defaults for monorepo compatibility. 
""" if config_path: path = Path(config_path) @@ -123,10 +125,14 @@ class AppConfig(BaseModel): raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}") return path else: - for path in _default_config_candidates(): + project_config = existing_project_file(("config.yaml",)) + if project_config is not None: + return project_config + + for path in _legacy_config_candidates(): if path.exists(): return path - raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations") + raise FileNotFoundError("`config.yaml` file not found in the project root or legacy backend/repository root locations") @classmethod def from_file(cls, config_path: str | None = None) -> Self: diff --git a/backend/packages/harness/deerflow/config/extensions_config.py b/backend/packages/harness/deerflow/config/extensions_config.py index e7a48d166..a2daa71f4 100644 --- a/backend/packages/harness/deerflow/config/extensions_config.py +++ b/backend/packages/harness/deerflow/config/extensions_config.py @@ -7,6 +7,8 @@ from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field +from deerflow.config.runtime_paths import existing_project_file + class McpOAuthConfig(BaseModel): """OAuth configuration for an MCP server (HTTP/SSE transports).""" @@ -73,8 +75,8 @@ class ExtensionsConfig(BaseModel): Priority: 1. If provided `config_path` argument, use it. 2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it. - 3. Otherwise, check for `extensions_config.json` in the current directory, then in the parent directory. - 4. For backward compatibility, also check for `mcp_config.json` if `extensions_config.json` is not found. + 3. Otherwise, search the caller project root for `extensions_config.json`, then `mcp_config.json`. + 4. For backward compatibility, also search legacy backend/repository-root defaults. 5. If not found, return None (extensions are optional). Args: @@ -83,8 +85,9 @@ class ExtensionsConfig(BaseModel): Resolution order: 1. If provided `config_path` argument, use it. 2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it. - 3. Otherwise, search backend/repository-root defaults for + 3. Otherwise, search the caller project root for `extensions_config.json`, then legacy `mcp_config.json`. + 4. Finally, search backend/repository-root defaults for monorepo compatibility. Returns: Path to the extensions config file if found, otherwise None. 
@@ -100,6 +103,10 @@ class ExtensionsConfig(BaseModel): raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}") return path else: + project_config = existing_project_file(("extensions_config.json", "mcp_config.json")) + if project_config is not None: + return project_config + backend_dir = Path(__file__).resolve().parents[4] repo_root = backend_dir.parent for path in ( diff --git a/backend/packages/harness/deerflow/config/paths.py b/backend/packages/harness/deerflow/config/paths.py index f1ce7eae1..9fa633f54 100644 --- a/backend/packages/harness/deerflow/config/paths.py +++ b/backend/packages/harness/deerflow/config/paths.py @@ -3,6 +3,8 @@ import re import shutil from pathlib import Path, PureWindowsPath +from deerflow.config.runtime_paths import runtime_home + # Virtual path prefix seen by agents inside the sandbox VIRTUAL_PATH_PREFIX = "/mnt/user-data" @@ -11,9 +13,8 @@ _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$") def _default_local_base_dir() -> Path: - """Return the repo-local DeerFlow state directory without relying on cwd.""" - backend_dir = Path(__file__).resolve().parents[4] - return backend_dir / ".deer-flow" + """Return the caller project's writable DeerFlow state directory.""" + return runtime_home() def _validate_thread_id(thread_id: str) -> str: @@ -81,7 +82,7 @@ class Paths: BaseDir resolution (in priority order): 1. Constructor argument `base_dir` 2. DEER_FLOW_HOME environment variable - 3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow` + 3. Caller project fallback: `{project_root}/.deer-flow` """ def __init__(self, base_dir: str | Path | None = None) -> None: diff --git a/backend/packages/harness/deerflow/config/runtime_paths.py b/backend/packages/harness/deerflow/config/runtime_paths.py new file mode 100644 index 000000000..25157106f --- /dev/null +++ b/backend/packages/harness/deerflow/config/runtime_paths.py @@ -0,0 +1,41 @@ +"""Runtime path resolution for standalone harness usage.""" + +import os +from pathlib import Path + + +def project_root() -> Path: + """Return the caller project root for runtime-owned files.""" + if env_root := os.getenv("DEER_FLOW_PROJECT_ROOT"): + root = Path(env_root).resolve() + if not root.exists(): + raise ValueError(f"DEER_FLOW_PROJECT_ROOT is set to '{env_root}', but the resolved path '{root}' does not exist.") + if not root.is_dir(): + raise ValueError(f"DEER_FLOW_PROJECT_ROOT is set to '{env_root}', but the resolved path '{root}' is not a directory.") + return root + return Path.cwd().resolve() + + +def runtime_home() -> Path: + """Return the writable DeerFlow state directory.""" + if env_home := os.getenv("DEER_FLOW_HOME"): + return Path(env_home).resolve() + return project_root() / ".deer-flow" + + +def resolve_path(value: str | os.PathLike[str], *, base: Path | None = None) -> Path: + """Resolve absolute paths as-is and relative paths against the project root.""" + path = Path(value) + if not path.is_absolute(): + path = (base or project_root()) / path + return path.resolve() + + +def existing_project_file(names: tuple[str, ...]) -> Path | None: + """Return the first existing named file under the project root.""" + root = project_root() + for name in names: + candidate = root / name + if candidate.is_file(): + return candidate + return None diff --git a/backend/packages/harness/deerflow/config/skills_config.py b/backend/packages/harness/deerflow/config/skills_config.py index 266a98b91..671b48fde 100644 --- 
a/backend/packages/harness/deerflow/config/skills_config.py +++ b/backend/packages/harness/deerflow/config/skills_config.py @@ -1,11 +1,9 @@ +import os from pathlib import Path from pydantic import BaseModel, Field - -def _default_repo_root() -> Path: - """Resolve the repo root without relying on the current working directory.""" - return Path(__file__).resolve().parents[5] +from deerflow.config.runtime_paths import project_root, resolve_path class SkillsConfig(BaseModel): @@ -17,7 +15,7 @@ class SkillsConfig(BaseModel): ) path: str | None = Field( default=None, - description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory", + description="Path to skills directory. If not specified, defaults to skills under the caller project root.", ) container_path: str = Field( default="/mnt/skills", @@ -32,15 +30,11 @@ class SkillsConfig(BaseModel): Path to the skills directory """ if self.path: - # Use configured path (can be absolute or relative) - path = Path(self.path) - if not path.is_absolute(): - # If relative, resolve from the repo root for deterministic behavior. - path = _default_repo_root() / path - return path.resolve() - else: - # Default: /skills - return _default_repo_root() / "skills" + # Use configured path (can be absolute or relative to project root) + return resolve_path(self.path) + if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"): + return resolve_path(env_path) + return project_root() / "skills" def get_skill_container_path(self, skill_name: str, category: str = "public") -> str: """ diff --git a/backend/packages/harness/deerflow/skills/storage/local_skill_storage.py b/backend/packages/harness/deerflow/skills/storage/local_skill_storage.py index 047cd6163..4b7dffde4 100644 --- a/backend/packages/harness/deerflow/skills/storage/local_skill_storage.py +++ b/backend/packages/harness/deerflow/skills/storage/local_skill_storage.py @@ -12,7 +12,7 @@ from collections.abc import Iterable from datetime import UTC, datetime from pathlib import Path -from deerflow.config.skills_config import _default_repo_root +from deerflow.config.runtime_paths import resolve_path from deerflow.skills.storage.skill_storage import SKILL_MD_FILE, SkillStorage from deerflow.skills.types import SkillCategory @@ -44,10 +44,7 @@ class LocalSkillStorage(SkillStorage): config = app_config or get_app_config() self._host_root: Path = config.skills.get_skills_path() else: - path = Path(host_path) - if not path.is_absolute(): - path = _default_repo_root() / path - self._host_root = path.resolve() + self._host_root = resolve_path(host_path) # ------------------------------------------------------------------ # Abstract operation implementations diff --git a/backend/tests/test_client_e2e.py b/backend/tests/test_client_e2e.py index 0c3872e41..4b6a62ea9 100644 --- a/backend/tests/test_client_e2e.py +++ b/backend/tests/test_client_e2e.py @@ -17,6 +17,7 @@ import json import os import uuid import zipfile +from pathlib import Path import pytest from dotenv import load_dotenv @@ -94,12 +95,18 @@ def e2e_env(tmp_path, monkeypatch): """Isolated filesystem environment for E2E tests. - DEER_FLOW_HOME → tmp_path (all thread data lands in a temp dir) + - DEER_FLOW_PROJECT_ROOT → repository root (shared skills/config assets + still resolve correctly when tests run from backend/) - Singletons reset so they pick up the new env - Title/memory/summarization disabled to avoid extra LLM calls - AppConfig built programmatically (avoids config.yaml param-name issues) """ # 1. 
Filesystem isolation monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path)) + monkeypatch.setenv( + "DEER_FLOW_PROJECT_ROOT", + str(Path(__file__).resolve().parents[2]), + ) monkeypatch.setattr("deerflow.config.paths._paths", None) monkeypatch.setattr("deerflow.sandbox.sandbox_provider._default_sandbox_provider", None) diff --git a/backend/tests/test_runtime_paths.py b/backend/tests/test_runtime_paths.py new file mode 100644 index 000000000..aa9e94641 --- /dev/null +++ b/backend/tests/test_runtime_paths.py @@ -0,0 +1,145 @@ +"""Runtime path policy tests for standalone harness usage.""" + +from pathlib import Path + +import pytest +import yaml + +from deerflow.config import app_config as app_config_module +from deerflow.config import extensions_config as extensions_config_module +from deerflow.config.app_config import AppConfig +from deerflow.config.extensions_config import ExtensionsConfig +from deerflow.config.paths import Paths +from deerflow.config.runtime_paths import project_root +from deerflow.config.skills_config import SkillsConfig +from deerflow.skills.storage import get_or_new_skill_storage + + +def _clear_path_env(monkeypatch): + for name in ( + "DEER_FLOW_CONFIG_PATH", + "DEER_FLOW_EXTENSIONS_CONFIG_PATH", + "DEER_FLOW_HOME", + "DEER_FLOW_PROJECT_ROOT", + "DEER_FLOW_SKILLS_PATH", + ): + monkeypatch.delenv(name, raising=False) + + +def test_default_runtime_paths_resolve_from_current_project(tmp_path: Path, monkeypatch): + _clear_path_env(monkeypatch) + monkeypatch.chdir(tmp_path) + + (tmp_path / "config.yaml").write_text( + yaml.safe_dump({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}), + encoding="utf-8", + ) + (tmp_path / "extensions_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8") + + assert AppConfig.resolve_config_path() == tmp_path / "config.yaml" + assert ExtensionsConfig.resolve_config_path() == tmp_path / "extensions_config.json" + assert Paths().base_dir == tmp_path / ".deer-flow" + assert SkillsConfig().get_skills_path() == tmp_path / "skills" + assert get_or_new_skill_storage(skills_path=SkillsConfig().get_skills_path()).get_skills_root_path() == tmp_path / "skills" + + +def test_deer_flow_project_root_overrides_current_directory(tmp_path: Path, monkeypatch): + _clear_path_env(monkeypatch) + project_root = tmp_path / "project" + other_cwd = tmp_path / "other" + project_root.mkdir() + other_cwd.mkdir() + monkeypatch.chdir(other_cwd) + monkeypatch.setenv("DEER_FLOW_PROJECT_ROOT", str(project_root)) + + (project_root / "config.yaml").write_text( + yaml.safe_dump({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}), + encoding="utf-8", + ) + (project_root / "mcp_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8") + + assert AppConfig.resolve_config_path() == project_root / "config.yaml" + assert ExtensionsConfig.resolve_config_path() == project_root / "mcp_config.json" + assert Paths().base_dir == project_root / ".deer-flow" + assert SkillsConfig(path="custom-skills").get_skills_path() == project_root / "custom-skills" + + +def test_deer_flow_skills_path_overrides_project_default(tmp_path: Path, monkeypatch): + _clear_path_env(monkeypatch) + monkeypatch.chdir(tmp_path) + monkeypatch.setenv("DEER_FLOW_SKILLS_PATH", "team-skills") + + assert SkillsConfig().get_skills_path() == tmp_path / "team-skills" + assert get_or_new_skill_storage(skills_path=SkillsConfig().get_skills_path()).get_skills_root_path() == tmp_path / "team-skills" + + +def 
test_deer_flow_project_root_must_exist(tmp_path: Path, monkeypatch): + _clear_path_env(monkeypatch) + missing_root = tmp_path / "missing" + monkeypatch.setenv("DEER_FLOW_PROJECT_ROOT", str(missing_root)) + + with pytest.raises(ValueError, match="does not exist"): + project_root() + + +def test_deer_flow_project_root_must_be_directory(tmp_path: Path, monkeypatch): + _clear_path_env(monkeypatch) + project_root_file = tmp_path / "project-root" + project_root_file.write_text("", encoding="utf-8") + monkeypatch.setenv("DEER_FLOW_PROJECT_ROOT", str(project_root_file)) + + with pytest.raises(ValueError, match="not a directory"): + project_root() + + +def test_app_config_falls_back_to_legacy_when_project_root_lacks_config(tmp_path: Path, monkeypatch): + """When DEER_FLOW_PROJECT_ROOT is unset and cwd has no config.yaml, the + legacy backend/repo-root candidates must be used for monorepo compatibility.""" + _clear_path_env(monkeypatch) + cwd = tmp_path / "cwd" + cwd.mkdir() + monkeypatch.chdir(cwd) + + legacy_backend = tmp_path / "legacy-backend" + legacy_repo = tmp_path / "legacy-repo" + legacy_backend.mkdir() + legacy_repo.mkdir() + legacy_backend_config = legacy_backend / "config.yaml" + legacy_backend_config.write_text( + yaml.safe_dump({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}), + encoding="utf-8", + ) + repo_root_config = legacy_repo / "config.yaml" + repo_root_config.write_text("", encoding="utf-8") + + monkeypatch.setattr( + app_config_module, + "_legacy_config_candidates", + lambda: (legacy_backend_config, repo_root_config), + ) + + assert AppConfig.resolve_config_path() == legacy_backend_config + + +def test_extensions_config_falls_back_to_legacy_when_project_root_lacks_file(tmp_path: Path, monkeypatch): + """ExtensionsConfig should hit the legacy backend/repo-root locations when + the caller project root has no extensions_config.json/mcp_config.json.""" + _clear_path_env(monkeypatch) + cwd = tmp_path / "cwd" + cwd.mkdir() + monkeypatch.chdir(cwd) + + fake_backend = tmp_path / "fake-backend" + fake_repo = tmp_path / "fake-repo" + fake_backend.mkdir() + fake_repo.mkdir() + legacy_extensions = fake_backend / "extensions_config.json" + legacy_extensions.write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8") + + fake_paths_module_file = fake_backend / "packages" / "harness" / "deerflow" / "config" / "extensions_config.py" + fake_paths_module_file.parent.mkdir(parents=True) + fake_paths_module_file.write_text("", encoding="utf-8") + + monkeypatch.setattr(extensions_config_module, "__file__", str(fake_paths_module_file)) + + assert ExtensionsConfig.resolve_config_path() == legacy_extensions diff --git a/backend/tests/test_skills_loader.py b/backend/tests/test_skills_loader.py index 5a03532c6..886090f71 100644 --- a/backend/tests/test_skills_loader.py +++ b/backend/tests/test_skills_loader.py @@ -14,12 +14,25 @@ def _write_skill(skill_dir: Path, name: str, description: str) -> None: (skill_dir / "SKILL.md").write_text(content, encoding="utf-8") -def test_get_skills_root_path_points_to_project_root_skills(): - """get_skills_root_path() should point to deer-flow/skills (sibling of backend/), not backend/packages/skills.""" +def test_get_skills_root_path_points_to_current_project_skills(tmp_path: Path, monkeypatch): + """get_skills_root_path() should point to the caller project skills directory.""" + monkeypatch.delenv("DEER_FLOW_SKILLS_PATH", raising=False) + monkeypatch.delenv("DEER_FLOW_PROJECT_ROOT", raising=False) + monkeypatch.chdir(tmp_path) + app_config = 
SimpleNamespace(skills=SkillsConfig()) path = get_or_new_skill_storage(app_config=app_config).get_skills_root_path() - assert path.name == "skills", f"Expected 'skills', got '{path.name}'" - assert (path.parent / "backend").is_dir(), f"Expected skills path's parent to be project root containing 'backend/', but got {path}" + assert path == tmp_path / "skills" + + +def test_get_skills_root_path_honors_env_override(tmp_path: Path, monkeypatch): + """DEER_FLOW_SKILLS_PATH should override the caller project skills directory.""" + skills_root = tmp_path / "team-skills" + monkeypatch.setenv("DEER_FLOW_SKILLS_PATH", str(skills_root)) + + app_config = SimpleNamespace(skills=SkillsConfig()) + path = get_or_new_skill_storage(app_config=app_config).get_skills_root_path() + assert path == skills_root def test_load_skills_discovers_nested_skills_and_sets_container_paths(tmp_path: Path): diff --git a/config.example.yaml b/config.example.yaml index 04ccd0b12..b16b4a6bb 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -2,8 +2,11 @@ # # Guidelines: # - Copy this file to `config.yaml` and customize it for your environment -# - The default path of this configuration file is `config.yaml` in the current working directory. -# However you can change it using the `DEER_FLOW_CONFIG_PATH` environment variable. +# - The default path of this configuration file is `config.yaml` in the project root. +# You can set `DEER_FLOW_PROJECT_ROOT` to define that root explicitly, or use +# `DEER_FLOW_CONFIG_PATH` to point at a specific config file. +# - Runtime state defaults to `.deer-flow` under the project root. Override it +# with `DEER_FLOW_HOME` when you need a different writable data directory. # - Environment variables are available for all field values. Example: `api_key: $OPENAI_API_KEY` # - The `use` path is a string that looks like "package_name.sub_package_name.module_name:class_name/variable_name". @@ -678,7 +681,8 @@ sandbox: skills: # Path to skills directory on the host (relative to project root or absolute) - # Default: ../skills (relative to backend directory) + # Default: skills under the project root + # Override with DEER_FLOW_SKILLS_PATH when this field is omitted. 
# Uncomment to customize: # path: /absolute/path/to/custom/skills diff --git a/docker/docker-compose-dev.yaml b/docker/docker-compose-dev.yaml index 8fb95124d..6d00d71ff 100644 --- a/docker/docker-compose-dev.yaml +++ b/docker/docker-compose-dev.yaml @@ -157,6 +157,7 @@ services: working_dir: /app environment: - CI=true + - DEER_FLOW_PROJECT_ROOT=/app - DEER_FLOW_HOME=/app/backend/.deer-flow - DEER_FLOW_CHANNELS_LANGGRAPH_URL=${DEER_FLOW_CHANNELS_LANGGRAPH_URL:-http://gateway:8001/api} - DEER_FLOW_CHANNELS_GATEWAY_URL=${DEER_FLOW_CHANNELS_GATEWAY_URL:-http://gateway:8001} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 82cb62425..8d82980d3 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -8,9 +8,11 @@ # - provisioner: (optional) Sandbox provisioner for Kubernetes mode # # Key environment variables (set via environment/.env or scripts/deploy.sh): -# DEER_FLOW_HOME — runtime data dir, default $REPO_ROOT/backend/.deer-flow +# DEER_FLOW_PROJECT_ROOT — project root for relative runtime paths +# DEER_FLOW_HOME — runtime data dir, default .deer-flow under $DEER_FLOW_PROJECT_ROOT (or cwd) # DEER_FLOW_CONFIG_PATH — path to config.yaml # DEER_FLOW_EXTENSIONS_CONFIG_PATH — path to extensions_config.json +# DEER_FLOW_SKILLS_PATH — skills dir, default $DEER_FLOW_PROJECT_ROOT/skills # DEER_FLOW_DOCKER_SOCKET — Docker socket path, default /var/run/docker.sock # DEER_FLOW_REPO_ROOT — repo root (used for skills host path in DooD) # BETTER_AUTH_SECRET — required for frontend auth/session security @@ -93,6 +95,7 @@ services: working_dir: /app environment: - CI=true + - DEER_FLOW_PROJECT_ROOT=/app - DEER_FLOW_HOME=/app/backend/.deer-flow - DEER_FLOW_CONFIG_PATH=/app/backend/config.yaml - DEER_FLOW_EXTENSIONS_CONFIG_PATH=/app/backend/extensions_config.json From 487c1d939fb150d107bf41f9b8a6508e06454610 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Fri, 1 May 2026 16:21:10 +0200 Subject: [PATCH 03/11] fix(subagents): use model override for tools and middleware (#2641) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(subagents): use model override for tools and middleware * fix(config): resolve effective subagent model * fix(subagents): defer app config loading * fix(subagents): fully defer config.yaml load in executor __init__ The previous attempt only relocated the explicit get_app_config() call, but left resolve_subagent_model_name(...) running eagerly in __init__. That helper has its own internal get_app_config() fallback, which still fired when both app_config and parent_model were None and config.model == "inherit" — exactly the path unit tests hit, breaking 21 tests in CI with FileNotFoundError: config.yaml. Skip the eager resolve in __init__ when it would require loading the config file, and defer to _create_agent (which already has the app_config or get_app_config() fallback). 
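As a hedged illustration of the resolution rule (helper and field names are taken from deerflow/subagents/config.py in this patch; the model names are made up), neither branch below touches config.yaml:

```python
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name

override = SubagentConfig(name="helper", description="demo", system_prompt="p",
                          model="vision-subagent-model")
inherit = SubagentConfig(name="helper", description="demo", system_prompt="p")

# An explicit per-subagent model wins over the parent's model.
assert resolve_subagent_model_name(override, parent_model="parent-text-model") == "vision-subagent-model"
# model="inherit" falls back to the parent; neither case loads config.yaml.
assert resolve_subagent_model_name(inherit, parent_model="parent-text-model") == "parent-text-model"
```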
--- .../tool_error_handling_middleware.py | 25 ++++- .../harness/deerflow/subagents/config.py | 25 +++++ .../harness/deerflow/subagents/executor.py | 45 ++++----- .../deerflow/tools/builtins/task_tool.py | 16 +++- backend/tests/test_subagent_executor.py | 1 + backend/tests/test_task_tool_core_logic.py | 55 ++++++++++- .../test_tool_error_handling_middleware.py | 91 +++++++++++++++++-- 7 files changed, 219 insertions(+), 39 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py index db0230cf9..4393bd360 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py @@ -136,11 +136,32 @@ def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = T ) -def build_subagent_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]: +def build_subagent_runtime_middlewares( + *, + app_config: AppConfig | None = None, + model_name: str | None = None, + lazy_init: bool = True, +) -> list[AgentMiddleware]: """Middlewares shared by subagent runtime before subagent-only middlewares.""" - return _build_runtime_middlewares( + if app_config is None: + from deerflow.config import get_app_config + + app_config = get_app_config() + + middlewares = _build_runtime_middlewares( app_config=app_config, include_uploads=False, include_dangling_tool_call_patch=True, lazy_init=lazy_init, ) + + if model_name is None and app_config.models: + model_name = app_config.models[0].name + + model_config = app_config.get_model_config(model_name) if model_name else None + if model_config is not None and model_config.supports_vision: + from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware + + middlewares.append(ViewImageMiddleware()) + + return middlewares diff --git a/backend/packages/harness/deerflow/subagents/config.py b/backend/packages/harness/deerflow/subagents/config.py index a2c961b9d..b0b094e28 100644 --- a/backend/packages/harness/deerflow/subagents/config.py +++ b/backend/packages/harness/deerflow/subagents/config.py @@ -1,6 +1,10 @@ """Subagent configuration definitions.""" from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from deerflow.config.app_config import AppConfig @dataclass @@ -29,3 +33,24 @@ class SubagentConfig: model: str = "inherit" max_turns: int = 50 timeout_seconds: int = 900 + + +def _default_model_name(app_config: "AppConfig") -> str: + if not app_config.models: + raise ValueError("No chat models are configured. 
Please configure at least one model in config.yaml.") + return app_config.models[0].name + + +def resolve_subagent_model_name(config: SubagentConfig, parent_model: str | None, *, app_config: "AppConfig | None" = None) -> str: + """Resolve the effective model name a subagent should use.""" + if config.model != "inherit": + return config.model + + if parent_model is not None: + return parent_model + + if app_config is None: + from deerflow.config import get_app_config + + app_config = get_app_config() + return _default_model_name(app_config) diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index 539244af8..ab850ede7 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -20,9 +20,10 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langchain_core.runnables import RunnableConfig from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState +from deerflow.config import get_app_config from deerflow.config.app_config import AppConfig from deerflow.models import create_chat_model -from deerflow.subagents.config import SubagentConfig +from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name logger = logging.getLogger(__name__) @@ -213,21 +214,6 @@ def _filter_tools( return filtered -def _get_model_name(config: SubagentConfig, parent_model: str | None) -> str | None: - """Resolve the model name for a subagent. - - Args: - config: Subagent configuration. - parent_model: The parent agent's model name. - - Returns: - Model name to use, or None to use default. - """ - if config.model == "inherit": - return parent_model - return config.model - - class SubagentExecutor: """Executor for running subagents.""" @@ -247,9 +233,9 @@ class SubagentExecutor: Args: config: Subagent configuration. tools: List of all available tools (will be filtered). - app_config: Resolved AppConfig; threaded into middleware factories - at agent-build time. When None, ``_create_agent`` falls back to - ``get_app_config()`` (matches the lead-agent factory's pattern). + app_config: Resolved AppConfig. When None, ``_create_agent`` falls + back to ``get_app_config()`` (matches the lead-agent factory's + pattern). parent_model: The parent agent's model name for inheritance. sandbox_state: Sandbox state from parent agent. thread_data: Thread data from parent agent. @@ -259,6 +245,13 @@ class SubagentExecutor: self.config = config self.app_config = app_config self.parent_model = parent_model + # Resolve eagerly only when it does not require loading config.yaml; otherwise defer + # to _create_agent (which already loads app_config) so unit tests can construct + # executors without a config file present. + if config.model != "inherit" or parent_model is not None or app_config is not None: + self.model_name: str | None = resolve_subagent_model_name(config, parent_model, app_config=app_config) + else: + self.model_name = None self.sandbox_state = sandbox_state self.thread_data = thread_data self.thread_id = thread_id @@ -276,17 +269,15 @@ class SubagentExecutor: def _create_agent(self): """Create the agent instance.""" - # Mirror lead-agent factory pattern: prefer explicit app_config, - # fall back to ambient lookup at agent-build time. 
- from deerflow.config import get_app_config - - resolved_app_config = self.app_config or get_app_config() - model_name = _get_model_name(self.config, self.parent_model) - model = create_chat_model(name=model_name, thinking_enabled=False, app_config=resolved_app_config) + app_config = self.app_config or get_app_config() + if self.model_name is None: + self.model_name = resolve_subagent_model_name(self.config, self.parent_model, app_config=app_config) + model = create_chat_model(name=self.model_name, thinking_enabled=False, app_config=app_config) from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares - middlewares = build_subagent_runtime_middlewares(app_config=resolved_app_config, lazy_init=True) + # Reuse shared middleware composition with lead agent. + middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True) return create_agent( model=model, diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 59613272c..42062f0aa 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -11,9 +11,16 @@ from langgraph.config import get_stream_writer from langgraph.typing import ContextT from deerflow.agents.thread_state import ThreadState +from deerflow.config import get_app_config from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config -from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task +from deerflow.subagents.config import resolve_subagent_model_name +from deerflow.subagents.executor import ( + SubagentStatus, + cleanup_background_task, + get_background_task_result, + request_cancel_background_task, +) logger = logging.getLogger(__name__) @@ -129,14 +136,19 @@ async def task_tool( # Inherit parent agent's tool_groups so subagents respect the same restrictions parent_tool_groups = metadata.get("tool_groups") + app_config = None + if config.model == "inherit" and parent_model is None: + app_config = get_app_config() + effective_model = resolve_subagent_model_name(config, parent_model, app_config=app_config) # Subagents should not have subagent tools enabled (prevent recursive nesting) - tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False) + tools = get_available_tools(model_name=effective_model, groups=parent_tool_groups, subagent_enabled=False) # Create executor executor = SubagentExecutor( config=config, tools=tools, + app_config=app_config, parent_model=parent_model, sandbox_state=sandbox_state, thread_data=thread_data, diff --git a/backend/tests/test_subagent_executor.py b/backend/tests/test_subagent_executor.py index 774bd2dd9..1b2251444 100644 --- a/backend/tests/test_subagent_executor.py +++ b/backend/tests/test_subagent_executor.py @@ -258,6 +258,7 @@ class TestAgentConstruction: } assert captured["middlewares"] == { "app_config": app_config, + "model_name": "parent-model", "lazy_init": True, } assert captured["agent"]["model"] is model diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index 1ae008df2..d436f1725 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ 
b/backend/tests/test_task_tool_core_logic.py @@ -223,6 +223,56 @@ def test_task_tool_propagates_tool_groups_to_subagent(monkeypatch): get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False) +def test_task_tool_uses_subagent_model_override_for_tool_loading(monkeypatch): + """Subagent model overrides should drive model-gated tool loading.""" + config = SubagentConfig( + name="general-purpose", + description="General helper", + system_prompt="Base system prompt", + model="vision-subagent-model", + max_turns=50, + timeout_seconds=10, + ) + runtime = _make_runtime() + runtime.config["metadata"]["model_name"] = "parent-text-model" + events = [] + get_available_tools = MagicMock(return_value=[]) + + class DummyExecutor: + def __init__(self, **kwargs): + pass + + def execute_async(self, prompt, task_id=None): + return task_id or "generated-task-id" + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor) + monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config) + monkeypatch.setattr( + task_tool_module, + "get_background_task_result", + lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"), + ) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools) + + output = _run_task_tool( + runtime=runtime, + description="inspect image", + prompt="inspect the uploaded image", + subagent_type="general-purpose", + tool_call_id="tc-issue-2543", + ) + + assert output == "Task Succeeded. Result: done" + get_available_tools.assert_called_once_with( + model_name="vision-subagent-model", + groups=None, + subagent_enabled=False, + ) + + def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch): config = _make_subagent_config() runtime = _make_runtime() @@ -371,6 +421,7 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch): monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools) + monkeypatch.setattr(task_tool_module, "get_app_config", lambda: SimpleNamespace(models=[SimpleNamespace(name="default-model")])) output = _run_task_tool( runtime=None, @@ -381,8 +432,8 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch): ) assert output == "Task Succeeded. Result: ok" - # runtime is None → metadata is empty dict → groups=None - get_available_tools.assert_called_once_with(model_name=None, groups=None, subagent_enabled=False) + # runtime is None -> metadata is empty dict -> groups=None, model falls back to app default. 
+ get_available_tools.assert_called_once_with(model_name="default-model", groups=None, subagent_enabled=False) config = _make_subagent_config() events = [] diff --git a/backend/tests/test_tool_error_handling_middleware.py b/backend/tests/test_tool_error_handling_middleware.py index 4add370f0..2c28dac35 100644 --- a/backend/tests/test_tool_error_handling_middleware.py +++ b/backend/tests/test_tool_error_handling_middleware.py @@ -9,11 +9,20 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import ( ToolErrorHandlingMiddleware, build_subagent_runtime_middlewares, ) +from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware from deerflow.config.app_config import AppConfig, CircuitBreakerConfig from deerflow.config.guardrails_config import GuardrailsConfig +from deerflow.config.model_config import ModelConfig from deerflow.config.sandbox_config import SandboxConfig +def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"): + tool_call = {"name": name} + if tool_call_id is not None: + tool_call["id"] = tool_call_id + return SimpleNamespace(tool_call=tool_call) + + def _module(name: str, **attrs): module = ModuleType(name) for key, value in attrs.items(): @@ -21,19 +30,62 @@ def _module(name: str, **attrs): return module -def _make_app_config() -> AppConfig: +def _make_app_config(*, supports_vision: bool = False) -> AppConfig: return AppConfig( + models=[ + ModelConfig( + name="test-model", + display_name="test-model", + description=None, + use="langchain_openai:ChatOpenAI", + model="test-model", + supports_vision=supports_vision, + ) + ], sandbox=SandboxConfig(use="test"), guardrails=GuardrailsConfig(enabled=False), circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11), ) -def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"): - tool_call = {"name": name} - if tool_call_id is not None: - tool_call["id"] = tool_call_id - return SimpleNamespace(tool_call=tool_call) +def _stub_runtime_middleware_imports(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeMiddleware: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + class FakeLLMErrorHandlingMiddleware: + def __init__(self, *, app_config): + self.app_config = app_config + + monkeypatch.setitem( + sys.modules, + "deerflow.agents.middlewares.llm_error_handling_middleware", + _module( + "deerflow.agents.middlewares.llm_error_handling_middleware", + LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware, + ), + ) + monkeypatch.setitem( + sys.modules, + "deerflow.agents.middlewares.thread_data_middleware", + _module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware), + ) + monkeypatch.setitem( + sys.modules, + "deerflow.sandbox.middleware", + _module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware), + ) + monkeypatch.setitem( + sys.modules, + "deerflow.agents.middlewares.dangling_tool_call_middleware", + _module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware), + ) + monkeypatch.setitem( + sys.modules, + "deerflow.agents.middlewares.sandbox_audit_middleware", + _module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware), + ) def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch): @@ -166,3 +218,30 @@ async def test_awrap_tool_call_reraises_graph_interrupt(): with pytest.raises(GraphInterrupt): await 
middleware.awrap_tool_call(req, _interrupt) + + +def test_subagent_runtime_middlewares_include_view_image_for_vision_model(monkeypatch): + app_config = _make_app_config(supports_vision=True) + _stub_runtime_middleware_imports(monkeypatch) + + middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model") + + assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares) + + +def test_subagent_runtime_middlewares_include_view_image_for_default_vision_model(monkeypatch): + app_config = _make_app_config(supports_vision=True) + _stub_runtime_middleware_imports(monkeypatch) + + middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=None) + + assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares) + + +def test_subagent_runtime_middlewares_skip_view_image_for_text_model(monkeypatch): + app_config = _make_app_config(supports_vision=False) + _stub_runtime_middleware_imports(monkeypatch) + + middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model") + + assert not any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares) From 189b82405c2b5e65652fe474ad9e1b334277a606 Mon Sep 17 00:00:00 2001 From: Willem Jiang Date: Fri, 1 May 2026 22:27:02 +0800 Subject: [PATCH 04/11] fix(sandbox): pass no_change_timeout to exec_command to prevent 120s premature termination (#2685) * fix(sandbox): pass no_change_timeout to exec_command to prevent 120s premature termination The agent_sandbox library's shell API defaults no_change_timeout to 120 seconds. When AioSandbox.execute_command() called exec_command() without this parameter, commands producing no output for 120s would return with NO_CHANGE_TIMEOUT status even though the script was still running. Pass no_change_timeout=600 to all exec_command calls (matching the client-level HTTP timeout) so long-running commands are not cut short. Fixes #2668 * test(sandbox): add assertions for no_change_timeout in execute_command and list_dir Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/2f37bc72-0826-4443-a6ba-e5b78c22fb5a Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> --- .../community/aio_sandbox/aio_sandbox.py | 12 +++-- backend/tests/test_aio_sandbox.py | 52 +++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py index b6041f7ed..97da4144d 100644 --- a/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py +++ b/backend/packages/harness/deerflow/community/aio_sandbox/aio_sandbox.py @@ -48,6 +48,12 @@ class AioSandbox(Sandbox): self._home_dir = context.home_dir return self._home_dir + # Default no_change_timeout for exec_command (seconds). Matches the + # client-level timeout so that long-running commands which produce no + # output are not prematurely terminated by the sandbox's built-in 120 s + # default. + _DEFAULT_NO_CHANGE_TIMEOUT = 600 + def execute_command(self, command: str) -> str: """Execute a shell command in the sandbox. 
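# --- Illustrative note (not part of the patch): the symptom this fix removes.
# The `sandbox` handle below is hypothetical; 120 s is the agent_sandbox
# shell API's built-in no_change_timeout default described in the commit
# message above.
#
#     sandbox.execute_command("sleep 180 && echo done")
#
# Before this change the call returned early at ~120 s with a
# NO_CHANGE_TIMEOUT status while the script was still running; with
# no_change_timeout=600 forwarded to exec_command it blocks until the
# command finishes and returns "done".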
@@ -66,13 +72,13 @@ class AioSandbox(Sandbox): """ with self._lock: try: - result = self._client.shell.exec_command(command=command) + result = self._client.shell.exec_command(command=command, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT) output = result.data.output if result.data else "" if output and _ERROR_OBSERVATION_SIGNATURE in output: logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session") fresh_id = str(uuid.uuid4()) - result = self._client.shell.exec_command(command=command, id=fresh_id) + result = self._client.shell.exec_command(command=command, id=fresh_id, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT) output = result.data.output if result.data else "" return output if output else "(no output)" @@ -108,7 +114,7 @@ class AioSandbox(Sandbox): """ with self._lock: try: - result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500") + result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500", no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT) output = result.data.output if result.data else "" if output: return [line.strip() for line in output.strip().split("\n") if line.strip()] diff --git a/backend/tests/test_aio_sandbox.py b/backend/tests/test_aio_sandbox.py index 789fbde20..c6acb46eb 100644 --- a/backend/tests/test_aio_sandbox.py +++ b/backend/tests/test_aio_sandbox.py @@ -133,6 +133,58 @@ class TestListDirSerialization: assert lock_was_held == [True], "list_dir must hold the lock during exec_command" +class TestNoChangeTimeout: + """Verify that no_change_timeout is forwarded to every exec_command call.""" + + def test_execute_command_passes_no_change_timeout(self, sandbox): + """execute_command should pass no_change_timeout to exec_command.""" + calls = [] + + def mock_exec(command, **kwargs): + calls.append(kwargs) + return SimpleNamespace(data=SimpleNamespace(output="ok")) + + sandbox._client.shell.exec_command = mock_exec + + sandbox.execute_command("echo hello") + + assert len(calls) == 1 + assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT + + def test_retry_passes_no_change_timeout(self, sandbox): + """The ErrorObservation retry path should also pass no_change_timeout.""" + calls = [] + + def mock_exec(command, **kwargs): + calls.append(kwargs) + if len(calls) == 1: + return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'")) + return SimpleNamespace(data=SimpleNamespace(output="ok")) + + sandbox._client.shell.exec_command = mock_exec + + sandbox.execute_command("echo hello") + + assert len(calls) == 2 + assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT + assert calls[1].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT + + def test_list_dir_passes_no_change_timeout(self, sandbox): + """list_dir should pass no_change_timeout to exec_command.""" + calls = [] + + def mock_exec(command, **kwargs): + calls.append(kwargs) + return SimpleNamespace(data=SimpleNamespace(output="/a\n/b")) + + sandbox._client.shell.exec_command = mock_exec + + sandbox.list_dir("/test") + + assert len(calls) == 1 + assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT + + class TestConcurrentFileWrites: """Verify file write paths do not lose concurrent updates.""" From 8ba01dfd836525f45a7faa40a9c30a86dac86995 Mon Sep 17 00:00:00 2001 From: 
greatmengqi Date: Sat, 2 May 2026 06:37:49 +0800 Subject: [PATCH 05/11] refactor: thread app_config through lead and subagent task path (#2666) * refactor: thread app config through lead prompt * fix: honor explicit app config across runtime paths * style: format subagent executor tests * fix: thread resolved app config and guard subagents-only fallback Address two PR review findings: 1. _create_summarization_middleware passed the original (possibly None) app_config into create_chat_model, forcing the model factory back to ambient get_app_config() and risking config drift between the middleware's resolved view and the model's view. Pass the resolved AppConfig instance through end-to-end. 2. get_available_subagent_names accepted Any-typed config and forwarded it to is_host_bash_allowed, which reads ``.sandbox``. A SubagentsAppConfig (also accepted upstream as a sum-type input) has no ``.sandbox`` attribute and would be silently treated as "no sandbox configured", incorrectly disabling the bash subagent. Guard on hasattr and fall back to ambient lookup otherwise. Adds regression tests for both paths. * chore: simplify hasattr guard and tighten regression tests - Collapse if/else into ternary in get_available_subagent_names; hasattr(None, ...) is False so the explicit None check was redundant. - Drop comments that narrate the change rather than explain non-obvious WHY (test names already convey intent). - Replace stringly-typed sentinel "no-arg" in regression test with direct args tuple comparison. --------- Co-authored-by: greatmengqi --- .../deerflow/agents/lead_agent/agent.py | 26 ++-- .../deerflow/agents/lead_agent/prompt.py | 98 +++++++----- .../agents/middlewares/memory_middleware.py | 12 +- .../agents/middlewares/title_middleware.py | 35 ++++- .../harness/deerflow/config/app_config.py | 3 +- .../harness/deerflow/runtime/runs/worker.py | 35 ++++- .../harness/deerflow/subagents/executor.py | 22 ++- .../harness/deerflow/subagents/registry.py | 66 ++++---- .../deerflow/tools/builtins/task_tool.py | 69 ++++++--- .../packages/harness/deerflow/tools/tools.py | 8 +- backend/tests/test_invoke_acp_agent_tool.py | 30 ++++ .../tests/test_lead_agent_model_resolution.py | 105 ++++++++++++- backend/tests/test_lead_agent_prompt.py | 141 +++++++++++++++++- backend/tests/test_lead_agent_skills.py | 4 + backend/tests/test_run_worker_rollback.py | 73 ++++++++- backend/tests/test_subagent_executor.py | 39 ++++- backend/tests/test_subagent_skills_config.py | 44 ++++++ backend/tests/test_task_tool_core_logic.py | 83 ++++++++++- .../tests/test_title_middleware_core_logic.py | 29 ++++ 19 files changed, 769 insertions(+), 153 deletions(-) diff --git a/backend/packages/harness/deerflow/agents/lead_agent/agent.py b/backend/packages/harness/deerflow/agents/lead_agent/agent.py index 12fedd5b2..a908e9f96 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/agent.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/agent.py @@ -19,8 +19,6 @@ from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddlewar from deerflow.agents.thread_state import ThreadState from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.app_config import AppConfig, get_app_config -from deerflow.config.memory_config import get_memory_config -from deerflow.config.summarization_config import get_summarization_config from deerflow.models import create_chat_model logger = logging.getLogger(__name__) @@ -52,7 +50,8 @@ def _resolve_model_name(requested_model_name: str 
| None = None, *, app_config: def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None: """Create and configure the summarization middleware from config.""" - config = get_summarization_config() + resolved_app_config = app_config or get_app_config() + config = resolved_app_config.summarization if not config.enabled: return None @@ -73,9 +72,9 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> # as middleware rather than lead_agent (SummarizationMiddleware is a # LangChain built-in, so we tag the model at creation time). if config.model_name: - model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config) + model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config) else: - model = create_chat_model(thinking_enabled=False, app_config=app_config) + model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config) model = model.with_config(tags=["middleware:summarize"]) # Prepare kwargs @@ -92,18 +91,13 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> kwargs["summary_prompt"] = config.summary_prompt hooks: list[BeforeSummarizationHook] = [] - if get_memory_config().enabled: + if resolved_app_config.memory.enabled: hooks.append(memory_flush_hook) # The logic below relies on two assumptions holding true: this factory is # the sole entry point for DeerFlowSummarizationMiddleware, and the runtime # config is not expected to change after startup. - try: - resolved_app_config = app_config or get_app_config() - skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills" - except Exception: - logger.exception("Failed to resolve skills container path; falling back to default") - skills_container_path = "/mnt/skills" + skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills" return DeerFlowSummarizationMiddleware( **kwargs, @@ -279,10 +273,10 @@ def _build_middlewares( middlewares.append(TokenUsageMiddleware()) # Add TitleMiddleware - middlewares.append(TitleMiddleware()) + middlewares.append(TitleMiddleware(app_config=resolved_app_config)) # Add MemoryMiddleware (after TitleMiddleware) - middlewares.append(MemoryMiddleware(agent_name=agent_name)) + middlewares.append(MemoryMiddleware(agent_name=agent_name, memory_config=resolved_app_config.memory)) # Add ViewImageMiddleware only if the current model supports vision. # Use the resolved runtime model_name from make_lead_agent to avoid stale config values. @@ -316,7 +310,9 @@ def _build_middlewares( def make_lead_agent(config: RunnableConfig): """LangGraph graph factory; keep the signature compatible with LangGraph Server.""" - return _make_lead_agent(config, app_config=get_app_config()) + runtime_config = _get_runtime_config(config) + runtime_app_config = runtime_config.get("app_config") + return _make_lead_agent(config, app_config=runtime_app_config or get_app_config()) def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig): diff --git a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py index 9b6fd9cd4..b02c86344 100644 --- a/backend/packages/harness/deerflow/agents/lead_agent/prompt.py +++ b/backend/packages/harness/deerflow/agents/lead_agent/prompt.py @@ -158,7 +158,7 @@ Skip simple one-off tasks. 
""" -def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str: +def _build_available_subagents_description(available_names: list[str], bash_available: bool, *, app_config: AppConfig | None = None) -> str: """Dynamically build subagent type descriptions from registry. Mirrors Codex's pattern where agent_type_description is dynamically generated @@ -180,7 +180,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai if name in builtin_descriptions: lines.append(f"- **{name}**: {builtin_descriptions[name]}") else: - config = get_subagent_config(name) + config = get_subagent_config(name, app_config=app_config) if config is not None: desc = config.description.split("\n")[0].strip() # First line only for brevity lines.append(f"- **{name}**: {desc}") @@ -188,7 +188,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai return "\n".join(lines) -def _build_subagent_section(max_concurrent: int) -> str: +def _build_subagent_section(max_concurrent: int, *, app_config: AppConfig | None = None) -> str: """Build the subagent system prompt section with dynamic concurrency limit. Args: @@ -198,12 +198,12 @@ def _build_subagent_section(max_concurrent: int) -> str: Formatted subagent section string. """ n = max_concurrent - available_names = get_available_subagent_names() + available_names = get_available_subagent_names(app_config=app_config) if app_config is not None else get_available_subagent_names() bash_available = "bash" in available_names # Dynamically build subagent type descriptions from registry (aligned with Codex's # agent_type_description pattern where all registered roles are listed in the tool spec). - available_subagents = _build_available_subagents_description(available_names, bash_available) + available_subagents = _build_available_subagents_description(available_names, bash_available, app_config=app_config) direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc." direct_execution_example = ( '# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()' @@ -530,21 +530,28 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f """ -def _get_memory_context(agent_name: str | None = None) -> str: +def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig | None = None) -> str: """Get memory context for injection into system prompt. Args: agent_name: If provided, loads per-agent memory. If None, loads global memory. + app_config: Explicit application config. When provided, memory options + are read from this value instead of the global config singleton. Returns: Formatted memory context string wrapped in XML tags, or empty string if disabled. 
""" try: from deerflow.agents.memory import format_memory_for_injection, get_memory_data - from deerflow.config.memory_config import get_memory_config from deerflow.runtime.user_context import get_effective_user_id - config = get_memory_config() + if app_config is None: + from deerflow.config.memory_config import get_memory_config + + config = get_memory_config() + else: + config = app_config.memory + if not config.enabled or not config.injection_enabled: return "" @@ -558,8 +565,8 @@ def _get_memory_context(agent_name: str | None = None) -> str: {memory_content} """ - except Exception as e: - logger.error("Failed to load memory context: %s", e) + except Exception: + logger.exception("Failed to load memory context") return "" @@ -599,15 +606,20 @@ def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_c """Generate the skills prompt section with available skills list.""" skills = _get_enabled_skills_for_config(app_config) - try: - from deerflow.config import get_app_config + if app_config is None: + try: + from deerflow.config import get_app_config - config = app_config or get_app_config() + config = get_app_config() + container_base_path = config.skills.container_path + skill_evolution_enabled = config.skill_evolution.enabled + except Exception: + container_base_path = "/mnt/skills" + skill_evolution_enabled = False + else: + config = app_config container_base_path = config.skills.container_path skill_evolution_enabled = config.skill_evolution.enabled - except Exception: - container_base_path = "/mnt/skills" - skill_evolution_enabled = False if not skills and not skill_evolution_enabled: return "" @@ -640,13 +652,17 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> """ from deerflow.tools.builtins.tool_search import get_deferred_registry - try: - from deerflow.config import get_app_config + if app_config is None: + try: + from deerflow.config import get_app_config - config = app_config or get_app_config() - if not config.tool_search.enabled: + config = get_app_config() + except Exception: return "" - except Exception: + else: + config = app_config + + if not config.tool_search.enabled: return "" registry = get_deferred_registry() @@ -657,15 +673,19 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> return f"\n{names}\n" -def _build_acp_section() -> str: +def _build_acp_section(*, app_config: AppConfig | None = None) -> str: """Build the ACP agent prompt section, only if ACP agents are configured.""" - try: - from deerflow.config.acp_config import get_acp_agents + if app_config is None: + try: + from deerflow.config.acp_config import get_acp_agents - agents = get_acp_agents() - if not agents: + agents = get_acp_agents() + except Exception: return "" - except Exception: + else: + agents = getattr(app_config, "acp_agents", {}) or {} + + if not agents: return "" return ( @@ -679,14 +699,18 @@ def _build_acp_section() -> str: def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str: """Build a prompt section for explicitly configured sandbox mounts.""" - try: - from deerflow.config import get_app_config + if app_config is None: + try: + from deerflow.config import get_app_config - config = app_config or get_app_config() - mounts = config.sandbox.mounts or [] - except Exception: - logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt") - return "" + config = get_app_config() + except Exception: + logger.exception("Failed to load configured sandbox 
mounts for the lead-agent prompt") + return "" + else: + config = app_config + + mounts = config.sandbox.mounts or [] if not mounts: return "" @@ -709,11 +733,11 @@ def apply_prompt_template( app_config: AppConfig | None = None, ) -> str: # Get memory context - memory_context = _get_memory_context(agent_name) + memory_context = _get_memory_context(agent_name, app_config=app_config) # Include subagent section only if enabled (from runtime parameter) n = max_concurrent_subagents - subagent_section = _build_subagent_section(n) if subagent_enabled else "" + subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else "" # Add subagent reminder to critical_reminders if enabled subagent_reminder = ( @@ -740,7 +764,7 @@ def apply_prompt_template( deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config) # Build ACP agent section only if ACP agents are configured - acp_section = _build_acp_section() + acp_section = _build_acp_section(app_config=app_config) custom_mounts_section = _build_custom_mounts_section(app_config=app_config) acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section) diff --git a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py index 059f8ffc2..ae5f9cfbb 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/memory_middleware.py @@ -1,7 +1,7 @@ """Middleware for memory mechanism.""" import logging -from typing import override +from typing import TYPE_CHECKING, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware @@ -13,6 +13,9 @@ from deerflow.agents.memory.queue import get_memory_queue from deerflow.config.memory_config import get_memory_config from deerflow.runtime.user_context import get_effective_user_id +if TYPE_CHECKING: + from deerflow.config.memory_config import MemoryConfig + logger = logging.getLogger(__name__) @@ -34,14 +37,17 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): state_schema = MemoryMiddlewareState - def __init__(self, agent_name: str | None = None): + def __init__(self, agent_name: str | None = None, *, memory_config: "MemoryConfig | None" = None): """Initialize the MemoryMiddleware. Args: agent_name: If provided, memory is stored per-agent. If None, uses global memory. + memory_config: Explicit memory config. When omitted, legacy global + config fallback is used. """ super().__init__() self._agent_name = agent_name + self._memory_config = memory_config @override def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None: @@ -54,7 +60,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]): Returns: None (no state changes needed from this middleware). 
""" - config = get_memory_config() + config = self._memory_config or get_memory_config() if not config.enabled: return None diff --git a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py index 5cd5bb46c..01080be14 100644 --- a/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py +++ b/backend/packages/harness/deerflow/agents/middlewares/title_middleware.py @@ -2,7 +2,7 @@ import logging import re -from typing import Any, NotRequired, override +from typing import TYPE_CHECKING, Any, NotRequired, override from langchain.agents import AgentState from langchain.agents.middleware import AgentMiddleware @@ -12,6 +12,10 @@ from langgraph.runtime import Runtime from deerflow.config.title_config import get_title_config from deerflow.models import create_chat_model +if TYPE_CHECKING: + from deerflow.config.app_config import AppConfig + from deerflow.config.title_config import TitleConfig + logger = logging.getLogger(__name__) @@ -26,6 +30,18 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]): state_schema = TitleMiddlewareState + def __init__(self, *, app_config: "AppConfig | None" = None, title_config: "TitleConfig | None" = None): + super().__init__() + self._app_config = app_config + self._title_config = title_config + + def _get_title_config(self): + if self._title_config is not None: + return self._title_config + if self._app_config is not None: + return self._app_config.title + return get_title_config() + def _normalize_content(self, content: object) -> str: if isinstance(content, str): return content @@ -47,7 +63,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]): def _should_generate_title(self, state: TitleMiddlewareState) -> bool: """Check if we should generate a title for this thread.""" - config = get_title_config() + config = self._get_title_config() if not config.enabled: return False @@ -72,7 +88,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]): Returns (prompt_string, user_msg) so callers can use user_msg as fallback. """ - config = get_title_config() + config = self._get_title_config() messages = state.get("messages", []) user_msg_content = next((m.content for m in messages if m.type == "human"), "") @@ -94,14 +110,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]): def _parse_title(self, content: object) -> str: """Normalize model output into a clean title string.""" - config = get_title_config() + config = self._get_title_config() title_content = self._normalize_content(content) title_content = self._strip_think_tags(title_content) title = title_content.strip().strip('"').strip("'") return title[: config.max_chars] if len(title) > config.max_chars else title def _fallback_title(self, user_msg: str) -> str: - config = get_title_config() + config = self._get_title_config() fallback_chars = min(config.max_chars, 50) if len(user_msg) > fallback_chars: return user_msg[:fallback_chars].rstrip() + "..." 
@@ -135,14 +151,17 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]): if not self._should_generate_title(state): return None - config = get_title_config() + config = self._get_title_config() prompt, user_msg = self._build_title_prompt(state) try: + model_kwargs = {"thinking_enabled": False} + if self._app_config is not None: + model_kwargs["app_config"] = self._app_config if config.model_name: - model = create_chat_model(name=config.model_name, thinking_enabled=False) + model = create_chat_model(name=config.model_name, **model_kwargs) else: - model = create_chat_model(thinking_enabled=False) + model = create_chat_model(**model_kwargs) response = await model.ainvoke(prompt, config=self._get_runnable_config()) title = self._parse_title(response.content) if title: diff --git a/backend/packages/harness/deerflow/config/app_config.py b/backend/packages/harness/deerflow/config/app_config.py index a41108372..dae41b14d 100644 --- a/backend/packages/harness/deerflow/config/app_config.py +++ b/backend/packages/harness/deerflow/config/app_config.py @@ -8,7 +8,7 @@ import yaml from dotenv import load_dotenv from pydantic import BaseModel, ConfigDict, Field -from deerflow.config.acp_config import load_acp_config_from_dict +from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict from deerflow.config.database_config import DatabaseConfig @@ -95,6 +95,7 @@ class AppConfig(BaseModel): summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration") memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration") agents_api: AgentsApiConfig = Field(default_factory=AgentsApiConfig, description="Custom-agent management API configuration") + acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP-compatible agent configuration") subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration") guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration") circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration") diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 1223c2127..d8f9c139b 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -21,7 +21,7 @@ import inspect import logging from dataclasses import dataclass, field from functools import lru_cache -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast if TYPE_CHECKING: from langchain_core.messages import HumanMessage @@ -39,12 +39,19 @@ logger = logging.getLogger(__name__) _VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"} -def _build_runtime_context(thread_id: str, run_id: str, caller_context: Any | None) -> dict[str, Any]: +def _build_runtime_context( + thread_id: str, + run_id: str, + caller_context: Any | None, + app_config: AppConfig | None = None, +) -> dict[str, Any]: """Build the dict that becomes ``ToolRuntime.context`` for the run. 
Always includes ``thread_id`` and ``run_id``. Additional keys from the caller's ``config['context']`` (e.g. ``agent_name`` for the bootstrap flow — issue #2677) - are merged in but never override ``thread_id``/``run_id``. + are merged in but never override ``thread_id``/``run_id``. The resolved + ``AppConfig`` is added by the worker so tools can consume it without ambient + global lookups. langgraph 1.1+ surfaces this as ``runtime.context`` via the parent runtime stored under ``config['configurable']['__pregel_runtime']`` — see @@ -54,6 +61,8 @@ def _build_runtime_context(thread_id: str, run_id: str, caller_context: Any | No if isinstance(caller_context, dict): for key, value in caller_context.items(): runtime_ctx.setdefault(key, value) + if app_config is not None: + runtime_ctx["app_config"] = app_config return runtime_ctx @@ -74,6 +83,18 @@ class RunContext: app_config: AppConfig | None = field(default=None) +def _install_runtime_context(config: dict, runtime_context: dict[str, Any]) -> None: + existing_context = config.get("context") + if isinstance(existing_context, dict): + existing_context.setdefault("thread_id", runtime_context["thread_id"]) + existing_context.setdefault("run_id", runtime_context["run_id"]) + if "app_config" in runtime_context: + existing_context["app_config"] = runtime_context["app_config"] + return + + config["context"] = dict(runtime_context) + + def _compute_agent_factory_supports_app_config(agent_factory: Any) -> bool: try: return "app_config" in inspect.signature(agent_factory).parameters @@ -191,11 +212,9 @@ async def run_agent( # access thread-level data. langgraph-cli does this automatically; we must do it # manually here because we drive the graph through ``agent.astream(config=...)`` # without passing the official ``context=`` parameter. - runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context")) - if "context" in config and isinstance(config["context"], dict): - config["context"].setdefault("thread_id", thread_id) - config["context"].setdefault("run_id", run_id) - runtime = Runtime(context=runtime_ctx, store=store) + runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config) + _install_runtime_context(config, runtime_ctx) + runtime = Runtime(context=cast(Any, runtime_ctx), store=store) config.setdefault("configurable", {})["__pregel_runtime"] = runtime # Inject RunJournal as a LangChain callback handler. 
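Taken together, the two worker helpers above give tools a deterministic view of the run context. A self-contained sketch of the contract (the `cfg` object and the literal ids are illustrative):

    ctx = _build_runtime_context("t-1", "r-1", {"agent_name": "bootstrap", "thread_id": "spoofed"}, cfg)
    assert ctx["thread_id"] == "t-1"          # caller keys merge via setdefault,
    assert ctx["agent_name"] == "bootstrap"   # so they can never shadow the run ids
    assert ctx["app_config"] is cfg           # worker-resolved config rides along

    config = {"context": {"agent_name": "bootstrap"}}
    _install_runtime_context(config, ctx)
    assert config["context"]["thread_id"] == "t-1"   # ids filled only when missing
    assert config["context"]["app_config"] is cfg    # app_config is always refreshed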
diff --git a/backend/packages/harness/deerflow/subagents/executor.py b/backend/packages/harness/deerflow/subagents/executor.py index ab850ede7..2fe5c05dc 100644 --- a/backend/packages/harness/deerflow/subagents/executor.py +++ b/backend/packages/harness/deerflow/subagents/executor.py @@ -168,6 +168,8 @@ def _get_isolated_subagent_loop() -> asyncio.AbstractEventLoop: _isolated_subagent_loop_thread = thread _isolated_subagent_loop_started = started_event + if _isolated_subagent_loop is None: + raise RuntimeError("Isolated subagent event loop is not initialized") return _isolated_subagent_loop @@ -308,8 +310,10 @@ class SubagentExecutor: try: from deerflow.skills.storage import get_or_new_skill_storage + storage_kwargs = {"app_config": self.app_config} if self.app_config is not None else {} + storage = await asyncio.to_thread(get_or_new_skill_storage, **storage_kwargs) # Use asyncio.to_thread to avoid blocking the event loop (LangGraph ASGI requirement) - all_skills = await asyncio.to_thread(get_or_new_skill_storage().load_skills, enabled_only=True) + all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True) logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk") except Exception: logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True) @@ -395,6 +399,10 @@ class SubagentExecutor: status=SubagentStatus.RUNNING, started_at=datetime.now(), ) + ai_messages = result.ai_messages + if ai_messages is None: + ai_messages = [] + result.ai_messages = ai_messages try: agent = self._create_agent() @@ -404,10 +412,12 @@ class SubagentExecutor: run_config: RunnableConfig = { "recursion_limit": self.config.max_turns, } - context = {} + context: dict[str, Any] = {} if self.thread_id: run_config["configurable"] = {"thread_id": self.thread_id} context["thread_id"] = self.thread_id + if self.app_config is not None: + context["app_config"] = self.app_config logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}") @@ -454,13 +464,13 @@ class SubagentExecutor: message_id = message_dict.get("id") is_duplicate = False if message_id: - is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages) + is_duplicate = any(msg.get("id") == message_id for msg in ai_messages) else: - is_duplicate = message_dict in result.ai_messages + is_duplicate = message_dict in ai_messages if not is_duplicate: - result.ai_messages.append(message_dict) - logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}") + ai_messages.append(message_dict) + logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}") logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution") diff --git a/backend/packages/harness/deerflow/subagents/registry.py b/backend/packages/harness/deerflow/subagents/registry.py index b34d7e9bd..4c4f3f183 100644 --- a/backend/packages/harness/deerflow/subagents/registry.py +++ b/backend/packages/harness/deerflow/subagents/registry.py @@ -2,6 +2,7 @@ import logging from dataclasses import replace +from typing import Any from deerflow.sandbox.security import is_host_bash_allowed from deerflow.subagents.builtins import BUILTIN_SUBAGENTS @@ -10,19 +11,26 @@ from deerflow.subagents.config import SubagentConfig logger = logging.getLogger(__name__) -def 
_build_custom_subagent_config(name: str) -> SubagentConfig | None: +def _resolve_subagents_app_config(app_config: Any | None = None): + if app_config is None: + from deerflow.config.subagents_config import get_subagents_app_config + + return get_subagents_app_config() + return getattr(app_config, "subagents", app_config) + + +def _build_custom_subagent_config(name: str, *, app_config: Any | None = None) -> SubagentConfig | None: """Build a SubagentConfig from config.yaml custom_agents section. Args: name: The name of the custom subagent. + app_config: Optional AppConfig or SubagentsAppConfig to resolve from. Returns: SubagentConfig if found in custom_agents, None otherwise. """ - from deerflow.config.subagents_config import get_subagents_app_config - - app_config = get_subagents_app_config() - custom = app_config.custom_agents.get(name) + subagents_config = _resolve_subagents_app_config(app_config) + custom = subagents_config.custom_agents.get(name) if custom is None: return None @@ -39,7 +47,7 @@ def _build_custom_subagent_config(name: str) -> SubagentConfig | None: ) -def get_subagent_config(name: str) -> SubagentConfig | None: +def get_subagent_config(name: str, *, app_config: Any | None = None) -> SubagentConfig | None: """Get a subagent configuration by name, with config.yaml overrides applied. Resolution order (mirrors Codex's config layering): @@ -49,6 +57,7 @@ def get_subagent_config(name: str) -> SubagentConfig | None: Args: name: The name of the subagent. + app_config: Optional AppConfig or SubagentsAppConfig to resolve overrides from. Returns: SubagentConfig if found (with any config.yaml overrides applied), None otherwise. @@ -56,7 +65,7 @@ def get_subagent_config(name: str) -> SubagentConfig | None: # Step 1: Look up built-in, then fall back to custom_agents config = BUILTIN_SUBAGENTS.get(name) if config is None: - config = _build_custom_subagent_config(name) + config = _build_custom_subagent_config(name, app_config=app_config) if config is None: return None @@ -65,12 +74,9 @@ def get_subagent_config(name: str) -> SubagentConfig | None: # (timeout_seconds, max_turns at the top level) apply to built-in agents # but must NOT override custom agents' own values — custom agents define # their own defaults in the custom_agents section. - # Lazy import to avoid circular deps. 
- from deerflow.config.subagents_config import get_subagents_app_config - - app_config = get_subagents_app_config() + subagents_config = _resolve_subagents_app_config(app_config) is_builtin = name in BUILTIN_SUBAGENTS - agent_override = app_config.agents.get(name) + agent_override = subagents_config.agents.get(name) overrides = {} @@ -79,27 +85,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None: if agent_override.timeout_seconds != config.timeout_seconds: logger.debug("Subagent '%s': timeout overridden (%ss -> %ss)", name, config.timeout_seconds, agent_override.timeout_seconds) overrides["timeout_seconds"] = agent_override.timeout_seconds - elif is_builtin and app_config.timeout_seconds != config.timeout_seconds: - logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, app_config.timeout_seconds) - overrides["timeout_seconds"] = app_config.timeout_seconds + elif is_builtin and subagents_config.timeout_seconds != config.timeout_seconds: + logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, subagents_config.timeout_seconds) + overrides["timeout_seconds"] = subagents_config.timeout_seconds # Max turns: per-agent override > global default (builtins only) > config's own value if agent_override is not None and agent_override.max_turns is not None: if agent_override.max_turns != config.max_turns: logger.debug("Subagent '%s': max_turns overridden (%s -> %s)", name, config.max_turns, agent_override.max_turns) overrides["max_turns"] = agent_override.max_turns - elif is_builtin and app_config.max_turns is not None and app_config.max_turns != config.max_turns: - logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, app_config.max_turns) - overrides["max_turns"] = app_config.max_turns + elif is_builtin and subagents_config.max_turns is not None and subagents_config.max_turns != config.max_turns: + logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, subagents_config.max_turns) + overrides["max_turns"] = subagents_config.max_turns # Model: per-agent override only (no global default for model) - effective_model = app_config.get_model_for(name) + effective_model = subagents_config.get_model_for(name) if effective_model is not None and effective_model != config.model: logger.debug("Subagent '%s': model overridden (%s -> %s)", name, config.model, effective_model) overrides["model"] = effective_model # Skills: per-agent override only (no global default for skills) - effective_skills = app_config.get_skills_for(name) + effective_skills = subagents_config.get_skills_for(name) if effective_skills is not None and effective_skills != config.skills: logger.debug("Subagent '%s': skills overridden (%s -> %s)", name, config.skills, effective_skills) overrides["skills"] = effective_skills @@ -110,21 +116,21 @@ def get_subagent_config(name: str) -> SubagentConfig | None: return config -def list_subagents() -> list[SubagentConfig]: +def list_subagents(*, app_config: Any | None = None) -> list[SubagentConfig]: """List all available subagent configurations (with config.yaml overrides applied). Returns: List of all registered SubagentConfig instances (built-in + custom). 
""" configs = [] - for name in get_subagent_names(): - config = get_subagent_config(name) + for name in get_subagent_names(app_config=app_config): + config = get_subagent_config(name, app_config=app_config) if config is not None: configs.append(config) return configs -def get_subagent_names() -> list[str]: +def get_subagent_names(*, app_config: Any | None = None) -> list[str]: """Get all available subagent names (built-in + custom). Returns: @@ -133,25 +139,23 @@ def get_subagent_names() -> list[str]: names = list(BUILTIN_SUBAGENTS.keys()) # Merge custom_agents from config.yaml - from deerflow.config.subagents_config import get_subagents_app_config - - app_config = get_subagents_app_config() - for custom_name in app_config.custom_agents: + subagents_config = _resolve_subagents_app_config(app_config) + for custom_name in subagents_config.custom_agents: if custom_name not in names: names.append(custom_name) return names -def get_available_subagent_names() -> list[str]: +def get_available_subagent_names(*, app_config: Any | None = None) -> list[str]: """Get subagent names that should be exposed to the active runtime. Returns: List of subagent names visible to the current sandbox configuration. """ - names = get_subagent_names() + names = get_subagent_names(app_config=app_config) try: - host_bash_allowed = is_host_bash_allowed() + host_bash_allowed = is_host_bash_allowed(app_config) if hasattr(app_config, "sandbox") else is_host_bash_allowed() except Exception: logger.debug("Could not determine host bash availability; exposing all subagents") return names diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 42062f0aa..1328507b2 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -4,7 +4,7 @@ import asyncio import logging import uuid from dataclasses import replace -from typing import Annotated +from typing import TYPE_CHECKING, Annotated, Any, cast from langchain.tools import InjectedToolCallId, ToolRuntime, tool from langgraph.config import get_stream_writer @@ -22,9 +22,21 @@ from deerflow.subagents.executor import ( request_cancel_background_task, ) +if TYPE_CHECKING: + from deerflow.config.app_config import AppConfig + logger = logging.getLogger(__name__) +def _get_runtime_app_config(runtime: Any) -> "AppConfig | None": + context = getattr(runtime, "context", None) + if isinstance(context, dict): + app_config = context.get("app_config") + if app_config is not None: + return cast("AppConfig", app_config) + return None + + def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -> list[str] | None: """Return the effective subagent skill allowlist under the parent policy.""" if parent is None: @@ -81,15 +93,18 @@ async def task_tool( subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD. max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max. 
""" - available_subagent_names = get_available_subagent_names() + runtime_app_config = _get_runtime_app_config(runtime) + available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names() # Get subagent configuration - config = get_subagent_config(subagent_type) + config = get_subagent_config(subagent_type, app_config=runtime_app_config) if runtime_app_config is not None else get_subagent_config(subagent_type) if config is None: available = ", ".join(available_subagent_names) return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}" - if subagent_type == "bash" and not is_host_bash_allowed(): - return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}" + if subagent_type == "bash": + host_bash_allowed = is_host_bash_allowed(runtime_app_config) if runtime_app_config is not None else is_host_bash_allowed() + if not host_bash_allowed: + return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}" # Build config overrides overrides: dict = {} @@ -136,25 +151,34 @@ async def task_tool( # Inherit parent agent's tool_groups so subagents respect the same restrictions parent_tool_groups = metadata.get("tool_groups") - app_config = None - if config.model == "inherit" and parent_model is None: - app_config = get_app_config() - effective_model = resolve_subagent_model_name(config, parent_model, app_config=app_config) + resolved_app_config = runtime_app_config + if config.model == "inherit" and parent_model is None and resolved_app_config is None: + resolved_app_config = get_app_config() + effective_model = resolve_subagent_model_name(config, parent_model, app_config=resolved_app_config) # Subagents should not have subagent tools enabled (prevent recursive nesting) - tools = get_available_tools(model_name=effective_model, groups=parent_tool_groups, subagent_enabled=False) + available_tools_kwargs = { + "model_name": effective_model, + "groups": parent_tool_groups, + "subagent_enabled": False, + } + if resolved_app_config is not None: + available_tools_kwargs["app_config"] = resolved_app_config + tools = get_available_tools(**available_tools_kwargs) # Create executor - executor = SubagentExecutor( - config=config, - tools=tools, - app_config=app_config, - parent_model=parent_model, - sandbox_state=sandbox_state, - thread_data=thread_data, - thread_id=thread_id, - trace_id=trace_id, - ) + executor_kwargs = { + "config": config, + "tools": tools, + "parent_model": parent_model, + "sandbox_state": sandbox_state, + "thread_data": thread_data, + "thread_id": thread_id, + "trace_id": trace_id, + } + if resolved_app_config is not None: + executor_kwargs["app_config"] = resolved_app_config + executor = SubagentExecutor(**executor_kwargs) # Start background execution (always async to prevent blocking) # Use tool_call_id as task_id for better traceability @@ -189,11 +213,12 @@ async def task_tool( last_status = result.status # Check for new AI messages and send task_running events - current_message_count = len(result.ai_messages) + ai_messages = result.ai_messages or [] + current_message_count = len(ai_messages) if current_message_count > last_message_count: # Send task_running event for each new message for i in range(last_message_count, current_message_count): - message = result.ai_messages[i] + message = ai_messages[i] writer( { "type": "task_running", diff --git a/backend/packages/harness/deerflow/tools/tools.py b/backend/packages/harness/deerflow/tools/tools.py index 2ba6eb6b4..14d93e65f 100644 --- 
a/backend/packages/harness/deerflow/tools/tools.py +++ b/backend/packages/harness/deerflow/tools/tools.py @@ -141,10 +141,14 @@ def get_available_tools( # Add invoke_acp_agent tool if any ACP agents are configured acp_tools: list[BaseTool] = [] try: - from deerflow.config.acp_config import get_acp_agents from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool - acp_agents = get_acp_agents() + if app_config is None: + from deerflow.config.acp_config import get_acp_agents + + acp_agents = get_acp_agents() + else: + acp_agents = getattr(config, "acp_agents", {}) or {} if acp_agents: acp_tools.append(build_invoke_acp_agent_tool(acp_agents)) logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})") diff --git a/backend/tests/test_invoke_acp_agent_tool.py b/backend/tests/test_invoke_acp_agent_tool.py index 3c5f6f0ff..8c44403b8 100644 --- a/backend/tests/test_invoke_acp_agent_tool.py +++ b/backend/tests/test_invoke_acp_agent_tool.py @@ -697,3 +697,33 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(mo assert "invoke_acp_agent" in [tool.name for tool in tools] load_acp_config_from_dict({}) + + +def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch): + explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")} + explicit_config = SimpleNamespace( + tools=[], + models=[], + tool_search=SimpleNamespace(enabled=False), + skill_evolution=SimpleNamespace(enabled=False), + get_model_config=lambda name: None, + acp_agents=explicit_agents, + ) + sentinel_tool = SimpleNamespace(name="invoke_acp_agent") + captured: dict[str, object] = {} + + def fail_get_acp_agents(): + raise AssertionError("ambient get_acp_agents() must not be used when app_config is explicit") + + def fake_build_invoke_acp_agent_tool(agents): + captured["agents"] = agents + return sentinel_tool + + monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True) + monkeypatch.setattr("deerflow.config.acp_config.get_acp_agents", fail_get_acp_agents) + monkeypatch.setattr("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", fake_build_invoke_acp_agent_tool) + + tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config) + + assert captured["agents"] is explicit_agents + assert "invoke_acp_agent" in [tool.name for tool in tools] diff --git a/backend/tests/test_lead_agent_model_resolution.py b/backend/tests/test_lead_agent_model_resolution.py index c22377b88..b240116cd 100644 --- a/backend/tests/test_lead_agent_model_resolution.py +++ b/backend/tests/test_lead_agent_model_resolution.py @@ -72,6 +72,44 @@ def test_internal_make_lead_agent_uses_explicit_app_config(monkeypatch): assert result["model"] is not None +def test_make_lead_agent_uses_runtime_app_config_from_context_without_global_read(monkeypatch): + app_config = _make_app_config([_make_model("context-model", supports_thinking=False)]) + + import deerflow.tools as tools_module + + def _raise_get_app_config(): + raise AssertionError("ambient get_app_config() must not be used when runtime context already carries app_config") + + monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config) + monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: []) + monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: []) + + captured: dict[str, object] = 
{} + + def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None): + captured["name"] = name + captured["app_config"] = app_config + return object() + + monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) + monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs) + + result = lead_agent_module.make_lead_agent( + { + "context": { + "model_name": "context-model", + "app_config": app_config, + } + } + ) + + assert captured == { + "name": "context-model", + "app_config": app_config, + } + assert result["model"] is not None + + def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog): app_config = _make_app_config( [ @@ -276,6 +314,16 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa ) monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None) monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None) + monkeypatch.setattr( + lead_agent_module, + "TitleMiddleware", + lambda *, app_config: captured.setdefault("title_app_config", app_config) or "title-middleware", + ) + monkeypatch.setattr( + lead_agent_module, + "MemoryMiddleware", + lambda agent_name=None, *, memory_config: captured.setdefault("memory_config", memory_config) or "memory-middleware", + ) middlewares = lead_agent_module._build_middlewares( {"configurable": {"is_plan_mode": False, "subagent_enabled": False}}, @@ -286,17 +334,16 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypa assert captured == { "app_config": app_config, "lazy_init": True, + "title_app_config": app_config, + "memory_config": app_config.memory, } assert middlewares[0] == "base-middleware" def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch): - monkeypatch.setattr( - lead_agent_module, - "get_summarization_config", - lambda: SummarizationConfig(enabled=True, model_name="model-masswork"), - ) - monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False)) + app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)]) + app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork") + app_config.memory = MemoryConfig(enabled=False) from unittest.mock import MagicMock @@ -311,13 +358,55 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch captured["app_config"] = app_config return fake_model + def _raise_get_app_config(): + raise AssertionError("ambient get_app_config() must not be used when app_config is explicit") + + monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config) monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs) - middleware = lead_agent_module._create_summarization_middleware(app_config=_make_app_config([_make_model("model-masswork", supports_thinking=False)])) + middleware = lead_agent_module._create_summarization_middleware(app_config=app_config) assert captured["name"] == "model-masswork" assert captured["thinking_enabled"] is False - assert captured["app_config"] is not None + assert captured["app_config"] is app_config assert middleware["model"] is fake_model fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"]) + + +def 
test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch): + fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)]) + fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model") + fallback_app_config.memory = MemoryConfig(enabled=False) + + from unittest.mock import MagicMock + + captured: dict[str, object] = {} + fake_model = MagicMock() + fake_model.with_config.return_value = fake_model + + def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None): + captured["app_config"] = app_config + return fake_model + + monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: fallback_app_config) + monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model) + monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs) + + lead_agent_module._create_summarization_middleware() + + assert captured["app_config"] is fallback_app_config + + +def test_memory_middleware_uses_explicit_memory_config_without_global_read(monkeypatch): + from deerflow.agents.middlewares import memory_middleware as memory_middleware_module + from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware + + def _raise_get_memory_config(): + raise AssertionError("ambient get_memory_config() must not be used when memory_config is explicit") + + monkeypatch.setattr(memory_middleware_module, "get_memory_config", _raise_get_memory_config) + + middleware = MemoryMiddleware(memory_config=MemoryConfig(enabled=False)) + + assert middleware.after_agent({"messages": []}, runtime=MagicMock(context={"thread_id": "thread-1"})) is None diff --git a/backend/tests/test_lead_agent_prompt.py b/backend/tests/test_lead_agent_prompt.py index edbcd5193..ecaca314a 100644 --- a/backend/tests/test_lead_agent_prompt.py +++ b/backend/tests/test_lead_agent_prompt.py @@ -4,6 +4,7 @@ from types import SimpleNamespace import anyio from deerflow.agents.lead_agent import prompt as prompt_module +from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig from deerflow.skills.types import Skill @@ -40,6 +41,21 @@ def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch): assert "read-only" in section +def test_build_custom_mounts_section_uses_explicit_app_config_without_global_read(monkeypatch): + mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)] + config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts)) + + def fail_get_app_config(): + raise AssertionError("ambient get_app_config() must not be used when app_config is explicit") + + monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config) + + section = prompt_module._build_custom_mounts_section(app_config=config) + + assert "`/home/user/shared`" in section + assert "read-write" in section + + def test_apply_prompt_template_includes_custom_mounts(monkeypatch): mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)] config = SimpleNamespace( @@ -49,8 +65,8 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch): monkeypatch.setattr("deerflow.config.get_app_config", lambda: config) monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: []) monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "") - monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "") - monkeypatch.setattr(prompt_module, 
"_get_memory_context", lambda agent_name=None: "") + monkeypatch.setattr(prompt_module, "_build_acp_section", lambda **kwargs: "") + monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None, **kwargs: "") monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "") prompt = prompt_module.apply_prompt_template() @@ -67,8 +83,8 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch): monkeypatch.setattr("deerflow.config.get_app_config", lambda: config) monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: []) monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "") - monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "") - monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "") + monkeypatch.setattr(prompt_module, "_build_acp_section", lambda **kwargs: "") + monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None, **kwargs: "") monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "") prompt = prompt_module.apply_prompt_template() @@ -77,6 +93,123 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch): assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt +def test_apply_prompt_template_threads_explicit_app_config_without_global_config(monkeypatch): + mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)] + explicit_config = SimpleNamespace( + sandbox=SimpleNamespace(mounts=mounts), + skills=SimpleNamespace(container_path="/mnt/explicit-skills"), + skill_evolution=SimpleNamespace(enabled=False), + tool_search=SimpleNamespace(enabled=False), + memory=SimpleNamespace(enabled=False, injection_enabled=True, max_injection_tokens=2000), + acp_agents={}, + ) + + def fail_get_app_config(): + raise AssertionError("ambient get_app_config() must not be used when app_config is explicit") + + def fail_get_memory_config(): + raise AssertionError("ambient get_memory_config() must not be used when app_config is explicit") + + monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config) + monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config) + monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None: SimpleNamespace(load_skills=lambda enabled_only=True: [])) + monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "") + + prompt = prompt_module.apply_prompt_template(app_config=explicit_config) + + assert "`/home/user/shared`" in prompt + assert "Custom Mounted Directories" in prompt + + +def test_apply_prompt_template_threads_explicit_app_config_to_subagents_without_global_config(monkeypatch): + explicit_config = SimpleNamespace( + sandbox=SimpleNamespace( + use="deerflow.sandbox.local:LocalSandboxProvider", + allow_host_bash=False, + mounts=[], + ), + subagents=SubagentsAppConfig( + custom_agents={ + "researcher": CustomSubagentConfig( + description="Research agent\nwith details", + system_prompt="You research.", + ) + } + ), + skills=SimpleNamespace(container_path="/mnt/skills"), + skill_evolution=SimpleNamespace(enabled=False), + tool_search=SimpleNamespace(enabled=False), + memory=SimpleNamespace(enabled=False, injection_enabled=True, max_injection_tokens=2000), + acp_agents={}, + ) + + def fail_get_app_config(): + raise AssertionError("ambient get_app_config() must not be used when app_config is explicit") + + def 
fail_get_subagents_app_config(): + raise AssertionError("ambient get_subagents_app_config() must not be used when app_config is explicit") + + monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config) + monkeypatch.setattr("deerflow.config.subagents_config.get_subagents_app_config", fail_get_subagents_app_config) + monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None: SimpleNamespace(load_skills=lambda enabled_only=True: [])) + monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "") + + prompt = prompt_module.apply_prompt_template(subagent_enabled=True, app_config=explicit_config) + + assert "**researcher**: Research agent" in prompt + assert "**bash**" not in prompt + + +def test_build_acp_section_uses_explicit_app_config_without_global_config(monkeypatch): + explicit_config = SimpleNamespace(acp_agents={"codex": object()}) + + def fail_get_acp_agents(): + raise AssertionError("ambient get_acp_agents() must not be used when app_config is explicit") + + monkeypatch.setattr("deerflow.config.acp_config.get_acp_agents", fail_get_acp_agents) + + section = prompt_module._build_acp_section(app_config=explicit_config) + + assert "ACP Agent Tasks" in section + assert "/mnt/acp-workspace/" in section + + +def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch): + explicit_config = SimpleNamespace( + memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234), + ) + captured: dict[str, object] = {} + + def fail_get_memory_config(): + raise AssertionError("ambient get_memory_config() must not be used when app_config is explicit") + + def fake_get_memory_data(agent_name=None, *, user_id=None): + captured["agent_name"] = agent_name + captured["user_id"] = user_id + return {"facts": []} + + def fake_format_memory_for_injection(memory_data, *, max_tokens): + captured["memory_data"] = memory_data + captured["max_tokens"] = max_tokens + return "remember this" + + monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config) + monkeypatch.setattr("deerflow.runtime.user_context.get_effective_user_id", lambda: "user-1") + monkeypatch.setattr("deerflow.agents.memory.get_memory_data", fake_get_memory_data) + monkeypatch.setattr("deerflow.agents.memory.format_memory_for_injection", fake_format_memory_for_injection) + + context = prompt_module._get_memory_context("agent-a", app_config=explicit_config) + + assert "" in context + assert "remember this" in context + assert captured == { + "agent_name": "agent-a", + "user_id": "user-1", + "memory_data": {"facts": []}, + "max_tokens": 1234, + } + + def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatch, tmp_path): def make_skill(name: str) -> Skill: skill_dir = tmp_path / name diff --git a/backend/tests/test_lead_agent_skills.py b/backend/tests/test_lead_agent_skills.py index fe983d916..576f6bd19 100644 --- a/backend/tests/test_lead_agent_skills.py +++ b/backend/tests/test_lead_agent_skills.py @@ -106,7 +106,11 @@ def test_get_skills_prompt_section_uses_explicit_config_for_enabled_skills(monke skill_evolution=SimpleNamespace(enabled=False), ) + def fail_get_app_config(): + raise AssertionError("ambient get_app_config() must not be used when app_config is explicit") + monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [_make_skill("global-skill")]) + monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config) monkeypatch.setattr( 
"deerflow.agents.lead_agent.prompt.get_or_new_skill_storage", lambda app_config=None, **kwargs: __import__("types").SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("explicit-skill")] if app_config is explicit_config else []), diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py index b2b8da77f..0c99663ad 100644 --- a/backend/tests/test_run_worker_rollback.py +++ b/backend/tests/test_run_worker_rollback.py @@ -1,8 +1,12 @@ +import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, call import pytest -from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _build_runtime_context, _rollback_to_pre_run_checkpoint +from deerflow.runtime.runs.manager import RunManager +from deerflow.runtime.runs.schemas import RunStatus +from deerflow.runtime.runs.worker import RunContext, _agent_factory_supports_app_config, _build_runtime_context, _install_runtime_context, _rollback_to_pre_run_checkpoint, run_agent class FakeCheckpointer: @@ -12,6 +16,73 @@ class FakeCheckpointer: self.aput_writes = AsyncMock() +def test_build_runtime_context_includes_app_config_when_present(): + app_config = object() + + context = _build_runtime_context("thread-1", "run-1", None, app_config) + + assert context["thread_id"] == "thread-1" + assert context["run_id"] == "run-1" + assert context["app_config"] is app_config + + +def test_install_runtime_context_preserves_existing_thread_id_and_threads_app_config(): + app_config = object() + config = {"context": {"thread_id": "caller-thread"}} + + _install_runtime_context( + config, + { + "thread_id": "record-thread", + "run_id": "run-1", + "app_config": app_config, + }, + ) + + assert config["context"]["thread_id"] == "caller-thread" + assert config["context"]["run_id"] == "run-1" + assert config["context"]["app_config"] is app_config + + +@pytest.mark.anyio +async def test_run_agent_threads_explicit_app_config_into_config_only_factory(): + run_manager = RunManager() + record = await run_manager.create("thread-1") + bridge = SimpleNamespace( + publish=AsyncMock(), + publish_end=AsyncMock(), + cleanup=AsyncMock(), + ) + app_config = object() + captured: dict[str, object] = {} + + class DummyAgent: + async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False): + captured["astream_context"] = config["context"] + yield {"messages": []} + + def factory(*, config): + captured["factory_context"] = config["context"] + return DummyAgent() + + await run_agent( + bridge, + run_manager, + record, + ctx=RunContext(checkpointer=None, app_config=app_config), + agent_factory=factory, + graph_input={}, + config={}, + ) + await asyncio.sleep(0) + + assert captured["factory_context"]["app_config"] is app_config + assert captured["astream_context"]["app_config"] is app_config + assert run_manager.get(record.run_id).status == RunStatus.success + bridge.publish_end.assert_awaited_once_with(record.run_id) + bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60) + + @pytest.mark.anyio async def test_rollback_restores_snapshot_without_deleting_thread(): checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) diff --git a/backend/tests/test_subagent_executor.py b/backend/tests/test_subagent_executor.py index 1b2251444..102ac091a 100644 --- a/backend/tests/test_subagent_executor.py +++ b/backend/tests/test_subagent_executor.py @@ -204,7 +204,7 @@ class TestAgentConstruction: 
SubagentExecutor = classes["SubagentExecutor"] - app_config = object() + app_config = SimpleNamespace(models=[SimpleNamespace(name="default-model")]) model = object() middlewares = [object()] agent = object() @@ -266,6 +266,43 @@ class TestAgentConstruction: assert captured["agent"]["tools"] == [] assert captured["agent"]["system_prompt"] == base_config.system_prompt + @pytest.mark.anyio + async def test_load_skill_messages_uses_explicit_app_config_for_skill_storage( + self, + classes, + base_config, + monkeypatch: pytest.MonkeyPatch, + tmp_path, + ): + """Explicit app_config must be threaded into subagent skill storage lookup.""" + SubagentExecutor = classes["SubagentExecutor"] + + app_config = SimpleNamespace(models=[SimpleNamespace(name="default-model")]) + skill_dir = tmp_path / "demo-skill" + skill_dir.mkdir() + skill_file = skill_dir / "SKILL.md" + skill_file.write_text("Use demo skill", encoding="utf-8") + captured: dict[str, object] = {} + + def fake_get_or_new_skill_storage(*, app_config=None): + captured["app_config"] = app_config + return SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="demo-skill", skill_file=skill_file)]) + + monkeypatch.setattr("deerflow.skills.storage.get_or_new_skill_storage", fake_get_or_new_skill_storage) + + executor = SubagentExecutor( + config=base_config, + tools=[], + app_config=app_config, + thread_id="test-thread", + ) + + messages = await executor._load_skill_messages() + + assert captured["app_config"] is app_config + assert len(messages) == 1 + assert "Use demo skill" in messages[0].content + # ----------------------------------------------------------------------------- # Async Execution Path Tests diff --git a/backend/tests/test_subagent_skills_config.py b/backend/tests/test_subagent_skills_config.py index f121ccf25..b1ca0c24d 100644 --- a/backend/tests/test_subagent_skills_config.py +++ b/backend/tests/test_subagent_skills_config.py @@ -9,6 +9,8 @@ Covers: - Skills filter passthrough in task_tool config assembly """ +from types import SimpleNamespace + import pytest from deerflow.config.subagents_config import ( @@ -343,12 +345,54 @@ class TestRegistryCustomAgentLookup: assert config.timeout_seconds == 600 assert config.model == "inherit" + def test_custom_agent_found_from_explicit_app_config_without_global_config(self, monkeypatch): + from deerflow.subagents.registry import get_subagent_config + + def fail_get_subagents_app_config(): + raise AssertionError("ambient get_subagents_app_config() must not be used when app_config is explicit") + + monkeypatch.setattr("deerflow.config.subagents_config.get_subagents_app_config", fail_get_subagents_app_config) + + app_config = SimpleNamespace( + subagents=SubagentsAppConfig( + custom_agents={ + "analysis": CustomSubagentConfig( + description="Data analysis specialist", + system_prompt="You are a data analysis subagent.", + skills=["data-analysis"], + ) + } + ) + ) + + config = get_subagent_config("analysis", app_config=app_config) + + assert config is not None + assert config.name == "analysis" + assert config.skills == ["data-analysis"] + def test_custom_agent_not_found(self): from deerflow.subagents.registry import get_subagent_config _reset_subagents_config() assert get_subagent_config("nonexistent") is None + def test_get_available_subagent_names_falls_back_when_subagents_app_config_lacks_sandbox(self, monkeypatch): + from deerflow.subagents import registry as registry_module + from deerflow.subagents.registry import get_available_subagent_names + + captured: dict[str, 
tuple] = {} + + def fake_is_host_bash_allowed(*args, **kwargs): + captured["args"] = args + return True + + monkeypatch.setattr(registry_module, "is_host_bash_allowed", fake_is_host_bash_allowed) + + get_available_subagent_names(app_config=SubagentsAppConfig()) + + assert captured["args"] == () + def test_builtin_takes_priority_over_custom(self): """If a custom agent has the same name as a builtin, builtin wins.""" from deerflow.subagents.builtins import BUILTIN_SUBAGENTS diff --git a/backend/tests/test_task_tool_core_logic.py b/backend/tests/test_task_tool_core_logic.py index d436f1725..428b7a066 100644 --- a/backend/tests/test_task_tool_core_logic.py +++ b/backend/tests/test_task_tool_core_logic.py @@ -24,8 +24,11 @@ class FakeSubagentStatus(Enum): TIMED_OUT = "timed_out" -def _make_runtime() -> SimpleNamespace: +def _make_runtime(*, app_config=None) -> SimpleNamespace: # Minimal ToolRuntime-like object; task_tool only reads these three attributes. + context = {"thread_id": "thread-1"} + if app_config is not None: + context["app_config"] = app_config return SimpleNamespace( state={ "sandbox": {"sandbox_id": "local"}, @@ -35,14 +38,14 @@ def _make_runtime() -> SimpleNamespace: "outputs_path": "/tmp/outputs", }, }, - context={"thread_id": "thread-1"}, + context=context, config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}}, ) -def _make_subagent_config() -> SubagentConfig: +def _make_subagent_config(name: str = "general-purpose") -> SubagentConfig: return SubagentConfig( - name="general-purpose", + name=name, description="General helper", system_prompt="Base system prompt", max_turns=50, @@ -112,6 +115,68 @@ def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch): assert result.startswith("Error: Bash subagent is disabled") +def test_task_tool_threads_runtime_app_config_to_subagent_dependencies(monkeypatch): + app_config = object() + config = _make_subagent_config(name="bash") + runtime = _make_runtime(app_config=app_config) + events = [] + captured = {} + + class DummyExecutor: + def __init__(self, **kwargs): + captured["executor_kwargs"] = kwargs + + def execute_async(self, prompt, task_id=None): + captured["prompt"] = prompt + return task_id or "generated-task-id" + + def fake_get_available_subagent_names(*, app_config): + captured["names_app_config"] = app_config + return ["bash"] + + def fake_get_subagent_config(name, *, app_config): + captured["config_lookup"] = (name, app_config) + return config + + def fake_is_host_bash_allowed(config): + captured["bash_gate_app_config"] = config + return True + + def fake_get_available_tools(**kwargs): + captured["tools_kwargs"] = kwargs + return ["tool-a"] + + monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus) + monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor) + monkeypatch.setattr(task_tool_module, "get_available_subagent_names", fake_get_available_subagent_names) + monkeypatch.setattr(task_tool_module, "get_subagent_config", fake_get_subagent_config) + monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", fake_is_host_bash_allowed) + monkeypatch.setattr( + task_tool_module, + "get_background_task_result", + lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"), + ) + monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) + monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) + monkeypatch.setattr("deerflow.tools.get_available_tools", fake_get_available_tools) + + output = _run_task_tool( + 
runtime=runtime, + description="运行命令", + prompt="inspect files", + subagent_type="bash", + tool_call_id="tc-explicit-config", + ) + + assert output == "Task Succeeded. Result: done" + assert captured["names_app_config"] is app_config + assert captured["config_lookup"] == ("bash", app_config) + assert captured["bash_gate_app_config"] is app_config + assert captured["tools_kwargs"]["app_config"] is app_config + assert captured["executor_kwargs"]["app_config"] is app_config + assert captured["executor_kwargs"]["tools"] == ["tool-a"] + + def test_task_tool_emits_running_and_completed_events(monkeypatch): config = _make_subagent_config() runtime = _make_runtime() @@ -421,7 +486,8 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch): monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append) monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep) monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools) - monkeypatch.setattr(task_tool_module, "get_app_config", lambda: SimpleNamespace(models=[SimpleNamespace(name="default-model")])) + fallback_app_config = SimpleNamespace(models=[SimpleNamespace(name="default-model")]) + monkeypatch.setattr(task_tool_module, "get_app_config", lambda: fallback_app_config) output = _run_task_tool( runtime=None, @@ -433,7 +499,12 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch): assert output == "Task Succeeded. Result: ok" # runtime is None -> metadata is empty dict -> groups=None, model falls back to app default. - get_available_tools.assert_called_once_with(model_name="default-model", groups=None, subagent_enabled=False) + get_available_tools.assert_called_once_with( + model_name="default-model", + groups=None, + subagent_enabled=False, + app_config=fallback_app_config, + ) config = _make_subagent_config() events = [] diff --git a/backend/tests/test_title_middleware_core_logic.py b/backend/tests/test_title_middleware_core_logic.py index afd10f2b3..ede4dc0a4 100644 --- a/backend/tests/test_title_middleware_core_logic.py +++ b/backend/tests/test_title_middleware_core_logic.py @@ -1,6 +1,7 @@ """Core behavior tests for TitleMiddleware.""" import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock from langchain_core.messages import AIMessage, HumanMessage @@ -98,6 +99,34 @@ class TestTitleMiddlewareCoreLogic: "tags": ["middleware:title"], } + def test_generate_title_uses_explicit_app_config_without_global_config(self, monkeypatch): + title_config = TitleConfig(enabled=True, model_name="title-model", max_chars=20) + app_config = SimpleNamespace(title=title_config) + middleware = TitleMiddleware(app_config=app_config) + model = MagicMock() + model.ainvoke = AsyncMock(return_value=AIMessage(content="显式标题")) + + def fail_get_title_config(): + raise AssertionError("ambient get_title_config() must not be used when app_config is explicit") + + monkeypatch.setattr(title_middleware_module, "get_title_config", fail_get_title_config) + monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model)) + + state = { + "messages": [ + HumanMessage(content="请帮我写一个标题"), + AIMessage(content="好的"), + ] + } + result = asyncio.run(middleware._agenerate_title_result(state)) + + assert result == {"title": "显式标题"} + title_middleware_module.create_chat_model.assert_called_once_with( + name="title-model", + thinking_enabled=False, + app_config=app_config, + ) + def test_generate_title_normalizes_structured_message_content(self, 
monkeypatch): _set_test_title_config(max_chars=20) middleware = TitleMiddleware() From 866d1ca4098cd5fbd81ff83ebe02954d9f54967b Mon Sep 17 00:00:00 2001 From: KiteEater <145987840+Kiteeater@users.noreply.github.com> Date: Sat, 2 May 2026 11:16:03 +0800 Subject: [PATCH 06/11] Populate Codex usage metadata for token accounting (#2585) --- .../deerflow/models/openai_codex_provider.py | 3 ++ backend/tests/test_codex_provider.py | 30 +++++++++++++++++ backend/tests/test_token_usage_middleware.py | 32 +++++++++++++++++++ 3 files changed, 65 insertions(+) create mode 100644 backend/tests/test_token_usage_middleware.py diff --git a/backend/packages/harness/deerflow/models/openai_codex_provider.py b/backend/packages/harness/deerflow/models/openai_codex_provider.py index 86dee0fc6..d8e46c2ae 100644 --- a/backend/packages/harness/deerflow/models/openai_codex_provider.py +++ b/backend/packages/harness/deerflow/models/openai_codex_provider.py @@ -21,6 +21,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_openai.chat_models.base import _create_usage_metadata_responses from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli_credential @@ -346,6 +347,7 @@ class CodexChatModel(BaseChatModel): ) usage = response.get("usage", {}) + usage_metadata = _create_usage_metadata_responses(usage) if usage else None additional_kwargs = {} if reasoning_content: additional_kwargs["reasoning_content"] = reasoning_content @@ -355,6 +357,7 @@ class CodexChatModel(BaseChatModel): tool_calls=tool_calls if tool_calls else [], invalid_tool_calls=invalid_tool_calls, additional_kwargs=additional_kwargs, + usage_metadata=usage_metadata, response_metadata={ "model": response.get("model", self.model), "usage": usage, diff --git a/backend/tests/test_codex_provider.py b/backend/tests/test_codex_provider.py index 65e53a21d..512154564 100644 --- a/backend/tests/test_codex_provider.py +++ b/backend/tests/test_codex_provider.py @@ -82,6 +82,36 @@ def test_parse_response_text_content(): assert result.generations[0].message.content == "Hello world" +def test_parse_response_populates_usage_metadata(): + model = _make_model() + response = { + "output": [ + { + "type": "message", + "content": [{"type": "output_text", "text": "Hello world"}], + } + ], + "usage": { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_tokens_details": {"cached_tokens": 3}, + "output_tokens_details": {"reasoning_tokens": 2}, + }, + "model": "gpt-5.4", + } + + result = model._parse_response(response) + + assert result.generations[0].message.usage_metadata == { + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + "input_token_details": {"cache_read": 3}, + "output_token_details": {"reasoning": 2}, + } + + def test_parse_response_reasoning_content(): model = _make_model() response = { diff --git a/backend/tests/test_token_usage_middleware.py b/backend/tests/test_token_usage_middleware.py new file mode 100644 index 000000000..66a1f2229 --- /dev/null +++ b/backend/tests/test_token_usage_middleware.py @@ -0,0 +1,32 @@ +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage + +from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware + + +def 
test_after_model_logs_usage_metadata_counts(): + middleware = TokenUsageMiddleware() + state = { + "messages": [ + AIMessage( + content="done", + usage_metadata={ + "input_tokens": 10, + "output_tokens": 5, + "total_tokens": 15, + }, + ) + ] + } + + with patch("deerflow.agents.middlewares.token_usage_middleware.logger.info") as info_mock: + result = middleware.after_model(state=state, runtime=MagicMock()) + + assert result is None + info_mock.assert_called_once_with( + "LLM token usage: input=%s output=%s total=%s", + 10, + 5, + 15, + ) From 17447fccbe91aa685f5363d85ee2b5c0afa323ce Mon Sep 17 00:00:00 2001 From: KiteEater <145987840+Kiteeater@users.noreply.github.com> Date: Sat, 2 May 2026 11:25:45 +0800 Subject: [PATCH 07/11] fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582) * Restore rollback checkpoints with fresh ids * Tighten rollback checkpoint tests and imports * Update test_run_worker_rollback.py --------- Co-authored-by: Willem Jiang --- backend/app/gateway/routers/threads.py | 3 +- .../harness/deerflow/runtime/runs/worker.py | 13 ++++ backend/pyproject.toml | 1 - backend/tests/test_run_worker_rollback.py | 77 +++++++++++++++---- 4 files changed, 75 insertions(+), 19 deletions(-) diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index 484582839..253717d11 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -18,6 +18,7 @@ import uuid from typing import Any from fastapi import APIRouter, HTTPException, Request +from langgraph.checkpoint.base import empty_checkpoint from pydantic import BaseModel, Field, field_validator from app.gateway.authz import require_permission @@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe # Write an empty checkpoint so state endpoints work immediately config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} try: - from langgraph.checkpoint.base import empty_checkpoint - ckpt_metadata = { "step": -1, "source": "input", diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index d8f9c139b..2aecb9a1b 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -23,6 +23,8 @@ from dataclasses import dataclass, field from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, cast +from langgraph.checkpoint.base import empty_checkpoint + if TYPE_CHECKING: from langchain_core.messages import HumanMessage @@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint( if checkpoint_to_restore.get("id") is None: logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id) return + restore_marker = _new_checkpoint_marker() + checkpoint_to_restore = { + **checkpoint_to_restore, + "id": restore_marker["id"], + "ts": restore_marker["ts"], + } metadata = pre_run_snapshot.get("metadata", {}) metadata_to_restore = metadata if isinstance(metadata, dict) else {} raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns") @@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint( ) +def _new_checkpoint_marker() -> dict[str, str]: + marker = empty_checkpoint() + return {"id": marker["id"], "ts": marker["ts"]} + + def _lg_mode_to_sse_event(mode: str) -> str: """Map LangGraph internal stream_mode name to SSE event name. 
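The crux of this fix is checkpoint ordering: checkpointers resolve the latest tuple by the highest checkpoint id (this is how InMemorySaver orders them, and what the new integration test pins), and empty_checkpoint() mints time-ordered, UUIDv6-style ids. Re-putting the pre-run snapshot under its original id would therefore leave the aborted run's newer checkpoints sorted above it. A minimal standalone sketch of that ordering property, assuming only the public InMemorySaver and empty_checkpoint APIs that the new test also exercises:

    from langgraph.checkpoint.base import empty_checkpoint
    from langgraph.checkpoint.memory import InMemorySaver

    saver = InMemorySaver()
    cfg = {"configurable": {"thread_id": "t", "checkpoint_ns": ""}}

    old = empty_checkpoint()  # stands in for the pre-run snapshot
    cfg_old = saver.put(cfg, old, {"step": 1}, {})
    saver.put(cfg_old, empty_checkpoint(), {"step": 2}, {})  # mid-run checkpoint

    # Re-putting `old` verbatim would not become latest; re-stamping it with a
    # freshly minted id/ts (what _new_checkpoint_marker provides) makes it win.
    marker = empty_checkpoint()
    restored = {**old, "id": marker["id"], "ts": marker["ts"]}
    saver.put(cfg, restored, {"step": 1}, {})
    assert saver.get_tuple(cfg).checkpoint["id"] == marker["id"]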
diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 64c6e74c3..1b74a77c4 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -47,4 +47,3 @@ members = ["packages/harness"] [tool.uv.sources] deerflow-harness = { workspace = true } - diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py index 0c99663ad..0a4421e2f 100644 --- a/backend/tests/test_run_worker_rollback.py +++ b/backend/tests/test_run_worker_rollback.py @@ -3,6 +3,8 @@ from types import SimpleNamespace from unittest.mock import AsyncMock, call import pytest +from langgraph.checkpoint.base import empty_checkpoint +from langgraph.checkpoint.memory import InMemorySaver from deerflow.runtime.runs.manager import RunManager from deerflow.runtime.runs.schemas import RunStatus @@ -16,6 +18,14 @@ class FakeCheckpointer: self.aput_writes = AsyncMock() +def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int): + checkpoint = empty_checkpoint() + checkpoint["id"] = checkpoint_id + checkpoint["channel_values"] = {"messages": messages} + checkpoint["channel_versions"] = {"messages": version} + return checkpoint + + def test_build_runtime_context_includes_app_config_when_present(): app_config = object() @@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread(): ) checkpointer.adelete_thread.assert_not_awaited() - checkpointer.aput.assert_awaited_once_with( - {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, - { - "id": "ckpt-1", - "channel_versions": {"messages": 3}, - "channel_values": {"messages": ["before"]}, - }, - {"source": "input"}, - {"messages": 3}, - ) + checkpointer.aput.assert_awaited_once() + restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args + assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + assert restored_checkpoint["id"] != "ckpt-1" + assert "channel_versions" in restored_checkpoint + assert "channel_values" in restored_checkpoint + assert restored_checkpoint["channel_versions"] == {"messages": 3} + assert restored_checkpoint["channel_values"] == {"messages": ["before"]} + assert restored_metadata == {"source": "input"} + assert new_versions == {"messages": 3} assert checkpointer.aput_writes.await_args_list == [ call( {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}, @@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread(): ] +@pytest.mark.anyio +async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer(): + checkpointer = InMemorySaver() + thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + before_checkpoint = _make_checkpoint("0001", ["before"], 1) + before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1}) + after_checkpoint = _make_checkpoint("0002", ["after"], 2) + after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2}) + checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after") + + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="0001", + pre_run_snapshot={ + "checkpoint_ns": "", + "checkpoint": before_checkpoint, + "metadata": {"step": 1}, + "pending_writes": [("task-before", "messages", "pending-before")], + }, + snapshot_capture_failed=False, + ) + + latest = 
checkpointer.get_tuple(thread_config)
+
+    assert latest is not None
+    assert latest.config["configurable"]["checkpoint_id"] != "0001"
+    assert latest.config["configurable"]["checkpoint_id"] != "0002"
+    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
+    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
+    assert ("task-after", "messages", "pending-after") not in latest.pending_writes
+
+
 @pytest.mark.anyio
 async def test_rollback_deletes_thread_when_no_snapshot_exists():
     checkpointer = FakeCheckpointer(put_result=None)
@@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
         snapshot_capture_failed=False,
     )

-    checkpointer.aput.assert_awaited_once_with(
-        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
-        {"id": "ckpt-1", "channel_versions": {}},
-        {},
-        {},
-    )
+    checkpointer.aput.assert_awaited_once()
+    restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
+    assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
+    assert restored_checkpoint["id"] != "ckpt-1"
+    assert restored_checkpoint["channel_versions"] == {}
+    assert restored_metadata == {}
+    assert new_versions == {}


 @pytest.mark.anyio

From bb8b234d85d2568d4458dc0f9a3b3d0adcf846af Mon Sep 17 00:00:00 2001
From: Willem Jiang
Date: Sat, 2 May 2026 15:04:11 +0800
Subject: [PATCH 08/11] chore(2585): keep polishing the Codex token usage code (#2689)

---
 .../deerflow/models/openai_codex_provider.py | 31 +++++++++++++++++--
 backend/tests/test_codex_provider.py         | 14 ++++-----
 2 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/backend/packages/harness/deerflow/models/openai_codex_provider.py b/backend/packages/harness/deerflow/models/openai_codex_provider.py
index d8e46c2ae..95cf12b09 100644
--- a/backend/packages/harness/deerflow/models/openai_codex_provider.py
+++ b/backend/packages/harness/deerflow/models/openai_codex_provider.py
@@ -21,13 +21,40 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage
 from langchain_core.outputs import ChatGeneration, ChatResult
-from langchain_openai.chat_models.base import _create_usage_metadata_responses

 from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli_credential

 logger = logging.getLogger(__name__)

 CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
+
+
+def _build_usage_metadata(oai_usage: dict) -> dict:
+    """Convert Codex/Responses API usage dict to LangChain usage_metadata format.
+
+    Maps OpenAI Responses API token usage fields to the dict structure that
+    LangChain AIMessage.usage_metadata expects. This avoids depending on
+    langchain_openai private helpers like ``_create_usage_metadata_responses``.
+    """
+    input_tokens = oai_usage.get("input_tokens", 0)
+    output_tokens = oai_usage.get("output_tokens", 0)
+    total_tokens = oai_usage.get("total_tokens", input_tokens + output_tokens)
+    metadata: dict = {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "total_tokens": total_tokens,
+    }
+    input_details = oai_usage.get("input_tokens_details") or {}
+    output_details = oai_usage.get("output_tokens_details") or {}
+    cache_read = input_details.get("cached_tokens")
+    if cache_read is not None:
+        metadata["input_token_details"] = {"cache_read": cache_read}
+    reasoning = output_details.get("reasoning_tokens")
+    if reasoning is not None:
+        metadata["output_token_details"] = {"reasoning": reasoning}
+    return metadata
+
+
 MAX_RETRIES = 3
@@ -347,7 +374,7 @@ class CodexChatModel(BaseChatModel):
         )

         usage = response.get("usage", {})
-        usage_metadata = _create_usage_metadata_responses(usage) if usage else None
+        usage_metadata = _build_usage_metadata(usage) if usage else None
         additional_kwargs = {}
         if reasoning_content:
             additional_kwargs["reasoning_content"] = reasoning_content
diff --git a/backend/tests/test_codex_provider.py b/backend/tests/test_codex_provider.py
index 512154564..1b9136b85 100644
--- a/backend/tests/test_codex_provider.py
+++ b/backend/tests/test_codex_provider.py
@@ -103,13 +103,13 @@ def test_parse_response_populates_usage_metadata():

     result = model._parse_response(response)

-    assert result.generations[0].message.usage_metadata == {
-        "input_tokens": 10,
-        "output_tokens": 5,
-        "total_tokens": 15,
-        "input_token_details": {"cache_read": 3},
-        "output_token_details": {"reasoning": 2},
-    }
+    meta = result.generations[0].message.usage_metadata
+    assert meta is not None
+    assert meta["input_tokens"] == 10
+    assert meta["output_tokens"] == 5
+    assert meta["total_tokens"] == 15
+    assert meta["input_token_details"]["cache_read"] == 3
+    assert meta["output_token_details"]["reasoning"] == 2


 def test_parse_response_reasoning_content():

From ca3332f8bf17d848b82cc85863ca955ed5b9adb8 Mon Sep 17 00:00:00 2001
From: Xinmin Zeng <135568692+fancyboi999@users.noreply.github.com>
Date: Sat, 2 May 2026 15:16:16 +0800
Subject: [PATCH 09/11] fix(gateway): return ISO 8601 timestamps from threads endpoints (#2599)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix(gateway): return ISO 8601 timestamps from threads endpoints (#2594)

ThreadResponse documents created_at / updated_at as ISO timestamps,
matching the LangGraph Platform schema (langgraph_sdk.schema.Thread
exposes them as datetime, JSON-encoded as ISO 8601). The gateway
threads router was instead emitting str(time.time()) — unix-second
floats — breaking frontend new Date() parsing and producing a mixed
ISO/unix wire format that also corrupted the search sort order.

Centralize timestamp generation in deerflow.utils.time:

- now_iso() — datetime.now(UTC).isoformat()
- coerce_iso(x) — heals legacy unix-timestamp strings on read so the
  store converges to ISO without a one-shot migration

threads.py: replace 6 time.time() call sites with now_iso(); wrap all
read paths and Phase-2 checkpoint metadata with coerce_iso();
_store_upsert opportunistically heals legacy created_at on update;
drop unused time import.

runs/manager.py: reuse now_iso() instead of a private duplicate
_now_iso(), preventing future drift between the two timestamp call
sites.
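The intended healing contract at a glance (a sketch, not part of the
diff; the legacy float literal is the one quoted from the issue, and it
converts to the UTC instant shown):

    from deerflow.utils.time import coerce_iso, now_iso

    now_iso()                                # current UTC instant, ISO 8601
    coerce_iso("1777252410.411327")          # legacy str(time.time())
                                             #   -> "2026-04-27T01:13:30.411327+00:00"
    coerce_iso("2026-04-27T01:13:30+00:00")  # already ISO -> unchanged
    coerce_iso("")                           # missing -> ""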
Tests: 9 unit tests for the helper; 5 integration tests pinning the ISO contract for create/get/patch/search and the legacy-healing path on the internal store upsert. Full suite: 2144 passed, 15 skipped, 0 failed. Closes #2594 * fix(gateway): coerce checkpoint metadata timestamps to ISO on read After the merge with main, three additional read paths in ``threads.py`` were still emitting raw ``str(metadata.get("created_at", ""))`` — ``get_thread_state``, ``update_thread_state``, and ``get_thread_history``. Same root cause as #2594: when the checkpoint metadata's ``created_at`` is a unix-second float (legacy data, or a checkpoint written by an older Gateway version), ``str(float)`` produces ``"1777252410.411327"`` and the frontend's ``new Date(...)`` returns ``Invalid Date``. The fix on the ``/threads/{id}`` GET path was already in place; these three sibling endpoints needed the same treatment. All four call sites now flow through ``coerce_iso``, so: - legacy float metadata heals to ISO on the way out, - ISO metadata passes through unchanged, - ``datetime`` instances (which the new ``coerce_iso`` branch handles explicitly) emit with the ``T`` separator instead of falling through to the space-separated ``str(datetime)`` form. Coverage added for the two endpoints not already pinned by the merge: - ``test_get_thread_state_returns_iso_for_legacy_checkpoint_metadata`` - ``test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata`` Both pre-seed a checkpoint whose metadata carries the literal float from the issue body and assert the wire format is ISO. --- backend/app/gateway/routers/threads.py | 41 +-- .../persistence/thread_meta/memory.py | 16 +- .../harness/deerflow/runtime/runs/manager.py | 7 +- .../packages/harness/deerflow/utils/time.py | 75 +++++ backend/tests/test_threads_router.py | 297 +++++++++++++++++- backend/tests/test_utils_time.py | 90 ++++++ 6 files changed, 494 insertions(+), 32 deletions(-) create mode 100644 backend/packages/harness/deerflow/utils/time.py create mode 100644 backend/tests/test_utils_time.py diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index 253717d11..cb048152e 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -13,7 +13,6 @@ matching the LangGraph Platform wire format expected by the from __future__ import annotations import logging -import time import uuid from typing import Any @@ -27,6 +26,7 @@ from app.gateway.utils import sanitize_log_param from deerflow.config.paths import Paths, get_paths from deerflow.runtime import serialize_channel_values from deerflow.runtime.user_context import get_effective_user_id +from deerflow.utils.time import coerce_iso, now_iso logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/threads", tags=["threads"]) @@ -234,7 +234,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe checkpointer = get_checkpointer(request) thread_store = get_thread_store(request) thread_id = body.thread_id or str(uuid.uuid4()) - now = time.time() + now = now_iso() # ``body.metadata`` is already stripped of server-reserved keys by # ``ThreadCreateRequest._strip_reserved`` — see the model definition. 
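The wire-format effect that the surrounding hunks pin down, shown as
illustrative payloads (abridged to the timestamp field; the values are
assumptions for illustration, not recorded responses):

    # before: frontend `new Date(created_at)` yields Invalid Date
    {"thread_id": "...", "status": "idle", "created_at": "1777252410.411327"}
    # after: ISO 8601 UTC parses cleanly and sorts lexicographically
    {"thread_id": "...", "status": "idle", "created_at": "2026-04-27T01:13:30.411327+00:00"}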
@@ -244,8 +244,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe return ThreadResponse( thread_id=thread_id, status=existing_record.get("status", "idle"), - created_at=str(existing_record.get("created_at", "")), - updated_at=str(existing_record.get("updated_at", "")), + created_at=coerce_iso(existing_record.get("created_at", "")), + updated_at=coerce_iso(existing_record.get("updated_at", "")), metadata=existing_record.get("metadata", {}), ) @@ -280,8 +280,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe return ThreadResponse( thread_id=thread_id, status="idle", - created_at=str(now), - updated_at=str(now), + created_at=now, + updated_at=now, metadata=body.metadata, ) @@ -306,8 +306,11 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th ThreadResponse( thread_id=r["thread_id"], status=r.get("status", "idle"), - created_at=r.get("created_at", ""), - updated_at=r.get("updated_at", ""), + # ``coerce_iso`` heals legacy unix-second values that + # ``MemoryThreadMetaStore`` historically wrote with ``time.time()``; + # SQL-backed rows already arrive as ISO strings and pass through. + created_at=coerce_iso(r.get("created_at", "")), + updated_at=coerce_iso(r.get("updated_at", "")), metadata=r.get("metadata", {}), values={"title": r["display_name"]} if r.get("display_name") else {}, interrupts={}, @@ -339,8 +342,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques return ThreadResponse( thread_id=thread_id, status=record.get("status", "idle"), - created_at=str(record.get("created_at", "")), - updated_at=str(record.get("updated_at", "")), + created_at=coerce_iso(record.get("created_at", "")), + updated_at=coerce_iso(record.get("updated_at", "")), metadata=record.get("metadata", {}), ) @@ -380,8 +383,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse: record = { "thread_id": thread_id, "status": "idle", - "created_at": ckpt_meta.get("created_at", ""), - "updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")), + "created_at": coerce_iso(ckpt_meta.get("created_at", "")), + "updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))), "metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}, } @@ -395,8 +398,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse: return ThreadResponse( thread_id=thread_id, status=status, - created_at=str(record.get("created_at", "")), - updated_at=str(record.get("updated_at", "")), + created_at=coerce_iso(record.get("created_at", "")), + updated_at=coerce_iso(record.get("updated_at", "")), metadata=record.get("metadata", {}), values=serialize_channel_values(channel_values), ) @@ -447,10 +450,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo values=values, next=next_tasks, metadata=metadata, - checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))}, + checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))}, checkpoint_id=checkpoint_id, parent_checkpoint_id=parent_checkpoint_id, - created_at=str(metadata.get("created_at", "")), + created_at=coerce_iso(metadata.get("created_at", "")), tasks=tasks, ) @@ -500,7 +503,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re channel_values.update(body.values) checkpoint["channel_values"] = channel_values - 
metadata["updated_at"] = time.time() + metadata["updated_at"] = now_iso() if body.as_node: metadata["source"] = "update" @@ -541,7 +544,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re next=[], metadata=metadata, checkpoint_id=new_checkpoint_id, - created_at=str(metadata.get("created_at", "")), + created_at=coerce_iso(metadata.get("created_at", "")), ) @@ -608,7 +611,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request parent_checkpoint_id=parent_id, metadata=user_meta, values=values, - created_at=str(metadata.get("created_at", "")), + created_at=coerce_iso(metadata.get("created_at", "")), next=next_tasks, ) ) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index ccf59ad42..fbe66fdaf 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -7,13 +7,13 @@ router for thread records. from __future__ import annotations -import time from typing import Any from langgraph.store.base import BaseStore from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id +from deerflow.utils.time import coerce_iso, now_iso THREADS_NS: tuple[str, ...] = ("threads",) @@ -48,7 +48,7 @@ class MemoryThreadMetaStore(ThreadMetaStore): metadata: dict | None = None, ) -> dict: resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create") - now = time.time() + now = now_iso() record: dict[str, Any] = { "thread_id": thread_id, "assistant_id": assistant_id, @@ -106,7 +106,7 @@ class MemoryThreadMetaStore(ThreadMetaStore): if record is None: return record["display_name"] = display_name - record["updated_at"] = time.time() + record["updated_at"] = now_iso() await self._store.aput(THREADS_NS, thread_id, record) async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: @@ -114,7 +114,7 @@ class MemoryThreadMetaStore(ThreadMetaStore): if record is None: return record["status"] = status - record["updated_at"] = time.time() + record["updated_at"] = now_iso() await self._store.aput(THREADS_NS, thread_id, record) async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None: @@ -124,7 +124,7 @@ class MemoryThreadMetaStore(ThreadMetaStore): merged = dict(record.get("metadata") or {}) merged.update(metadata) record["metadata"] = merged - record["updated_at"] = time.time() + record["updated_at"] = now_iso() await self._store.aput(THREADS_NS, thread_id, record) async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: @@ -144,6 +144,8 @@ class MemoryThreadMetaStore(ThreadMetaStore): "display_name": val.get("display_name"), "status": val.get("status", "idle"), "metadata": val.get("metadata", {}), - "created_at": str(val.get("created_at", "")), - "updated_at": str(val.get("updated_at", "")), + # ``coerce_iso`` heals legacy unix-second values written by + # earlier Gateway versions that called ``str(time.time())``. 
+ "created_at": coerce_iso(val.get("created_at", "")), + "updated_at": coerce_iso(val.get("updated_at", "")), } diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index a54a408b8..533342c87 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -6,9 +6,10 @@ import asyncio import logging import uuid from dataclasses import dataclass, field -from datetime import UTC, datetime from typing import TYPE_CHECKING +from deerflow.utils.time import now_iso as _now_iso + from .schemas import DisconnectMode, RunStatus if TYPE_CHECKING: @@ -17,10 +18,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def _now_iso() -> str: - return datetime.now(UTC).isoformat() - - @dataclass class RunRecord: """Mutable record for a single run.""" diff --git a/backend/packages/harness/deerflow/utils/time.py b/backend/packages/harness/deerflow/utils/time.py new file mode 100644 index 000000000..307a4b6b0 --- /dev/null +++ b/backend/packages/harness/deerflow/utils/time.py @@ -0,0 +1,75 @@ +"""ISO 8601 timestamp helpers for the Gateway and embedded runtime. + +DeerFlow stores and serializes thread/run timestamps as ISO 8601 UTC +strings to match the LangGraph Platform schema (see +``langgraph_sdk.schema.Thread``, where ``created_at`` / ``updated_at`` +are ``datetime`` and JSON-encode to ISO 8601). All timestamp generation +should funnel through :func:`now_iso` so the wire format stays +consistent across endpoints, the embedded ``RunManager``, and the +checkpoint metadata written by the Gateway. + +:func:`coerce_iso` provides a forward-compatible read path for legacy +records that historically stored ``str(time.time())`` floats. +""" + +from __future__ import annotations + +import re +from datetime import UTC, datetime + +__all__ = ["coerce_iso", "now_iso"] + +_UNIX_TIMESTAMP_PATTERN = re.compile(r"^\d{10}(?:\.\d+)?$") +"""Matches the unix-timestamp string shape historically written by +``str(time.time())`` (10-digit seconds with optional fractional part). +The 10-digit anchor avoids accidentally rewriting ISO years like +``"2026"`` and stays valid until the year 2286. +""" + + +def now_iso() -> str: + """Return the current UTC time as an ISO 8601 string. + + Example: ``"2026-04-27T03:19:46.511479+00:00"``. + """ + return datetime.now(UTC).isoformat() + + +def coerce_iso(value: object) -> str: + """Best-effort coerce a stored timestamp to an ISO 8601 string. + + Translates legacy unix-timestamp floats / strings written by older + DeerFlow versions into ISO without a one-shot migration. ISO strings + pass through unchanged; ``datetime`` instances are normalised to UTC + (tz-naive values are assumed to be UTC) and emitted via + ``isoformat()`` so the wire format always uses the ``T`` separator; + empty values become ``""``; unrecognised values are stringified as a + last resort. + """ + if value is None or value == "": + return "" + if isinstance(value, bool): + # ``bool`` is a subclass of ``int`` — treat as garbage, not 0/1. + return str(value) + if isinstance(value, datetime): + # ``datetime`` must be handled before the ``int``/``float`` check; + # str(datetime) would produce ``"YYYY-MM-DD HH:MM:SS+00:00"`` + # (space separator), which breaks strict ISO 8601 consumers. 
+ if value.tzinfo is None: + value = value.replace(tzinfo=UTC) + else: + value = value.astimezone(UTC) + return value.isoformat() + if isinstance(value, (int, float)): + try: + return datetime.fromtimestamp(float(value), UTC).isoformat() + except (ValueError, OverflowError, OSError): + return str(value) + if isinstance(value, str): + if _UNIX_TIMESTAMP_PATTERN.match(value): + try: + return datetime.fromtimestamp(float(value), UTC).isoformat() + except (ValueError, OverflowError, OSError): + return value + return value + return str(value) diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index 4ffa28a8c..daf0c0b13 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -1,12 +1,66 @@ +import re from unittest.mock import patch import pytest from _router_auth_helpers import make_authed_test_app -from fastapi import HTTPException +from fastapi import FastAPI, HTTPException from fastapi.testclient import TestClient +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.store.memory import InMemoryStore from app.gateway.routers import threads from deerflow.config.paths import Paths +from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore + +_ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") + + +class _PermissiveThreadMetaStore(MemoryThreadMetaStore): + """Memory store that skips user-id filtering for router tests. + + Owner isolation is exercised separately in + ``test_memory_thread_meta_isolation.py``. Router tests need to drive + the FastAPI surface end-to-end with a single fixed app user, but the + stub auth middleware in ``_router_auth_helpers`` stamps a fresh UUID + on every request, so the production filtering would reject every + pre-seeded record. Bypass that filter so the test can focus on the + timestamp wire format. + """ + + async def _get_owned_record(self, thread_id, user_id, method_name): # type: ignore[override] + item = await self._store.aget(THREADS_NS, thread_id) + return dict(item.value) if item is not None else None + + async def check_access(self, thread_id, user_id, *, require_existing=False): # type: ignore[override] + item = await self._store.aget(THREADS_NS, thread_id) + if item is None: + return not require_existing + return True + + async def create(self, thread_id, *, assistant_id=None, user_id=None, display_name=None, metadata=None): # type: ignore[override] + return await super().create(thread_id, assistant_id=assistant_id, user_id=None, display_name=display_name, metadata=metadata) + + async def search(self, *, metadata=None, status=None, limit=100, offset=0, user_id=None): # type: ignore[override] + return await super().search(metadata=metadata, status=status, limit=limit, offset=offset, user_id=None) + + +def _build_thread_app() -> tuple[FastAPI, InMemoryStore, InMemorySaver]: + """Build a stub-authed FastAPI app wired with an in-memory ThreadMetaStore. + + The thread_store on ``app.state`` is a permissive subclass of + ``MemoryThreadMetaStore`` so tests can drive ``/api/threads`` + end-to-end and pre-seed legacy records via the underlying BaseStore. + + Returns ``(app, store, checkpointer)`` for direct seeding/inspection. 
+ """ + app = make_authed_test_app() + store = InMemoryStore() + checkpointer = InMemorySaver() + app.state.store = store + app.state.checkpointer = checkpointer + app.state.thread_store = _PermissiveThreadMetaStore(store) + app.include_router(threads.router) + return app, store, checkpointer def test_delete_thread_data_removes_thread_directory(tmp_path): @@ -136,3 +190,244 @@ def test_strip_reserved_metadata_empty_input(): def test_strip_reserved_metadata_strips_all_reserved_keys(): out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"}) assert out == {"keep": "me"} + + +# --------------------------------------------------------------------------- +# ISO 8601 timestamp contract (issue #2594) +# --------------------------------------------------------------------------- +# +# Threads endpoints document ``created_at`` / ``updated_at`` as ISO +# timestamps and that is the format LangGraph Platform uses +# (``langgraph_sdk.schema.Thread.created_at: datetime`` JSON-encodes to +# ISO 8601). The tests below pin that contract end-to-end and also +# exercise the ``coerce_iso`` healing path for legacy unix-timestamp +# records written by older Gateway versions. + + +def test_create_thread_returns_iso_timestamps() -> None: + app, _store, _checkpointer = _build_thread_app() + + with TestClient(app) as client: + response = client.post("/api/threads", json={"metadata": {}}) + + assert response.status_code == 200, response.text + body = response.json() + assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"] + assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"] + assert body["created_at"] == body["updated_at"] + + +def test_get_thread_returns_iso_for_legacy_unix_record() -> None: + """A thread record written by older versions stores ``time.time()`` + floats. ``get_thread`` must transparently surface them as ISO so the + frontend's ``new Date(...)`` parser does not break. 
+ """ + app, store, checkpointer = _build_thread_app() + + legacy_thread_id = "legacy-thread" + legacy_ts = "1777252410.411327" + + async def _seed() -> None: + await store.aput( + THREADS_NS, + legacy_thread_id, + { + "thread_id": legacy_thread_id, + "status": "idle", + "created_at": legacy_ts, + "updated_at": legacy_ts, + "metadata": {}, + }, + ) + from langgraph.checkpoint.base import empty_checkpoint + + await checkpointer.aput( + {"configurable": {"thread_id": legacy_thread_id, "checkpoint_ns": ""}}, + empty_checkpoint(), + {"step": -1, "source": "input", "writes": None, "parents": {}}, + {}, + ) + + import asyncio + + asyncio.run(_seed()) + + with TestClient(app) as client: + response = client.get(f"/api/threads/{legacy_thread_id}") + + assert response.status_code == 200, response.text + body = response.json() + assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"] + assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"] + + +def test_patch_thread_returns_iso_and_advances_updated_at() -> None: + app, store, _checkpointer = _build_thread_app() + thread_id = "patch-target" + + legacy_created = "1777000000.000000" + legacy_updated = "1777000000.000000" + + async def _seed() -> None: + await store.aput( + THREADS_NS, + thread_id, + { + "thread_id": thread_id, + "status": "idle", + "created_at": legacy_created, + "updated_at": legacy_updated, + "metadata": {"k": "v0"}, + }, + ) + + import asyncio + + asyncio.run(_seed()) + + with TestClient(app) as client: + response = client.patch(f"/api/threads/{thread_id}", json={"metadata": {"k": "v1"}}) + + assert response.status_code == 200, response.text + body = response.json() + assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"] + assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"] + # Patch issues a fresh ``updated_at`` via ``MemoryThreadMetaStore.update_metadata``, + # so it must be > the migrated legacy ``created_at`` (both ISO strings + # sort lexicographically by time when the format is consistent). + assert body["updated_at"] > body["created_at"] + assert body["metadata"] == {"k": "v1"} + + +def test_search_threads_normalizes_legacy_unix_seconds_to_iso() -> None: + """``MemoryThreadMetaStore`` may hold legacy ``time.time()`` floats + written by older Gateway versions. ``/search`` must surface them as + ISO via ``coerce_iso`` so the frontend's ``new Date(...)`` parser + does not break. + """ + app, store, _checkpointer = _build_thread_app() + + async def _seed() -> None: + # Legacy unix-second float (the literal value from issue #2594). + await store.aput( + THREADS_NS, + "legacy", + { + "thread_id": "legacy", + "status": "idle", + "created_at": 1777000000.0, + "updated_at": 1777000000.0, + "metadata": {}, + }, + ) + # Modern ISO string, slightly later. 
+ await store.aput( + THREADS_NS, + "modern", + { + "thread_id": "modern", + "status": "idle", + "created_at": "2026-04-27T00:00:00+00:00", + "updated_at": "2026-04-27T00:00:00+00:00", + "metadata": {}, + }, + ) + + import asyncio + + asyncio.run(_seed()) + + with TestClient(app) as client: + response = client.post("/api/threads/search", json={"limit": 10}) + + assert response.status_code == 200, response.text + items = response.json() + assert {item["thread_id"] for item in items} == {"legacy", "modern"} + for item in items: + assert _ISO_TIMESTAMP_RE.match(item["created_at"]), item + assert _ISO_TIMESTAMP_RE.match(item["updated_at"]), item + + +def test_memory_thread_meta_store_writes_iso_on_create() -> None: + """``MemoryThreadMetaStore.create`` must emit ISO so newly created + threads serialize correctly without depending on the router's + ``coerce_iso`` heal path. + """ + import asyncio + + store = InMemoryStore() + repo = MemoryThreadMetaStore(store) + + async def _scenario() -> dict: + await repo.create("fresh", user_id=None, metadata={"a": 1}) + record = (await store.aget(THREADS_NS, "fresh")).value + return record + + record = asyncio.run(_scenario()) + assert _ISO_TIMESTAMP_RE.match(record["created_at"]), record + assert _ISO_TIMESTAMP_RE.match(record["updated_at"]), record + + +def test_get_thread_state_returns_iso_for_legacy_checkpoint_metadata() -> None: + """Checkpoints written by older Gateway versions stored + ``created_at`` as a unix-second float in their metadata. The + ``/state`` endpoint must surface that value as ISO so the frontend's + ``new Date(...)`` parser does not break — same root cause as the + thread-record bug fixed in #2594, but on the checkpoint side. + """ + app, _store, checkpointer = _build_thread_app() + thread_id = "legacy-state" + + async def _seed() -> None: + from langgraph.checkpoint.base import empty_checkpoint + + await checkpointer.aput( + {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}, + empty_checkpoint(), + {"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327}, + {}, + ) + + import asyncio + + asyncio.run(_seed()) + + with TestClient(app) as client: + response = client.get(f"/api/threads/{thread_id}/state") + + assert response.status_code == 200, response.text + body = response.json() + assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"] + assert _ISO_TIMESTAMP_RE.match(body["checkpoint"]["ts"]), body["checkpoint"] + + +def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None: + """``/history`` walks ``checkpointer.alist`` and emits one entry per + checkpoint. Each entry's ``created_at`` must come out as ISO even if + older checkpoints stored a unix-second float in their metadata. 
+ """ + app, _store, checkpointer = _build_thread_app() + thread_id = "legacy-history" + + async def _seed() -> None: + from langgraph.checkpoint.base import empty_checkpoint + + await checkpointer.aput( + {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}, + empty_checkpoint(), + {"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327}, + {}, + ) + + import asyncio + + asyncio.run(_seed()) + + with TestClient(app) as client: + response = client.post(f"/api/threads/{thread_id}/history", json={"limit": 10}) + + assert response.status_code == 200, response.text + entries = response.json() + assert entries, "expected at least one history entry" + for entry in entries: + assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry diff --git a/backend/tests/test_utils_time.py b/backend/tests/test_utils_time.py new file mode 100644 index 000000000..d873876c2 --- /dev/null +++ b/backend/tests/test_utils_time.py @@ -0,0 +1,90 @@ +"""Tests for ``deerflow.utils.time``.""" + +from __future__ import annotations + +import re +from datetime import UTC, datetime, timedelta, timezone + +from deerflow.utils.time import coerce_iso, now_iso + +_ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") + + +def test_now_iso_is_utc_iso8601() -> None: + value = now_iso() + assert _ISO_RE.match(value), value + parsed = datetime.fromisoformat(value) + assert parsed.tzinfo is not None + assert parsed.tzinfo.utcoffset(parsed) == UTC.utcoffset(parsed) + + +def test_coerce_iso_passes_iso_through() -> None: + iso = "2026-04-27T01:13:30.411334+00:00" + assert coerce_iso(iso) == iso + + +def test_coerce_iso_converts_unix_float_string() -> None: + legacy = "1777252410.411327" + out = coerce_iso(legacy) + assert _ISO_RE.match(out), out + # Round-trip: parsed timestamp matches the original epoch. + parsed = datetime.fromisoformat(out) + assert abs(parsed.timestamp() - 1777252410.411327) < 1e-3 + + +def test_coerce_iso_converts_unix_int_string() -> None: + out = coerce_iso("1700000000") + assert _ISO_RE.match(out), out + + +def test_coerce_iso_converts_numeric_types() -> None: + out_float = coerce_iso(1777252410.411327) + out_int = coerce_iso(1700000000) + assert _ISO_RE.match(out_float) + assert _ISO_RE.match(out_int) + + +def test_coerce_iso_handles_empty_and_none() -> None: + assert coerce_iso(None) == "" + assert coerce_iso("") == "" + + +def test_coerce_iso_does_not_misinterpret_short_numeric() -> None: + # A 4-digit year should never be parsed as a unix timestamp; only + # 10-digit unix-second strings match the legacy pattern. + assert coerce_iso("2026") == "2026" + + +def test_coerce_iso_handles_unparseable_string() -> None: + assert coerce_iso("not-a-timestamp") == "not-a-timestamp" + + +def test_coerce_iso_rejects_bool() -> None: + # ``bool`` is a subclass of ``int`` — must not be treated as epoch 0/1. + assert coerce_iso(True) == "True" + assert coerce_iso(False) == "False" + + +def test_coerce_iso_handles_tz_aware_datetime() -> None: + # str(datetime) would emit a space separator; coerce_iso must use ``T``. 
+ dt = datetime(2026, 4, 27, 1, 13, 30, 411327, tzinfo=UTC) + out = coerce_iso(dt) + assert out == "2026-04-27T01:13:30.411327+00:00" + assert "T" in out and " " not in out + + +def test_coerce_iso_handles_tz_naive_datetime_as_utc() -> None: + dt = datetime(2026, 4, 27, 1, 13, 30, 411327) + out = coerce_iso(dt) + assert out == "2026-04-27T01:13:30.411327+00:00" + parsed = datetime.fromisoformat(out) + assert parsed.tzinfo is not None + assert parsed.utcoffset() == timedelta(0) + + +def test_coerce_iso_normalises_non_utc_datetime_to_utc() -> None: + # +08:00 wall-clock 09:13 == UTC 01:13. + plus_eight = timezone(timedelta(hours=8)) + dt = datetime(2026, 4, 27, 9, 13, 30, 411327, tzinfo=plus_eight) + out = coerce_iso(dt) + assert out == "2026-04-27T01:13:30.411327+00:00" From e543bbf5d6b657be05e90ca4264c98cc2c3add70 Mon Sep 17 00:00:00 2001 From: Hinotobi Date: Sat, 2 May 2026 15:19:28 +0800 Subject: [PATCH 10/11] [security] fix(upload): reject symlinked upload destinations (#2623) * fix: reject symlinked upload destinations * test: harden upload destination checks * fix: address PR feedback for #2623 * test: cover safe upload re-uploads * fix: preserve upload limit checks after rebase * fix(upload): stream safe HTTP upload writes --- backend/app/channels/manager.py | 13 ++- backend/app/gateway/routers/uploads.py | 48 +++++--- .../harness/deerflow/uploads/manager.py | 64 +++++++++++ .../tests/test_channel_file_attachments.py | 106 +++++++++++++++++- backend/tests/test_uploads_manager.py | 54 +++++++++ backend/tests/test_uploads_router.py | 100 +++++++++++++++++ 6 files changed, 369 insertions(+), 16 deletions(-) diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index c09b13173..349fa1bfe 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -420,7 +420,13 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic if not msg.files: return [] - from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename + from deerflow.uploads.manager import ( + UnsafeUploadPathError, + claim_unique_filename, + ensure_uploads_dir, + normalize_filename, + write_upload_file_no_symlink, + ) uploads_dir = ensure_uploads_dir(thread_id) seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()} @@ -471,7 +477,10 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic dest = uploads_dir / safe_name try: - dest.write_bytes(data) + dest = write_upload_file_no_symlink(uploads_dir, safe_name, data) + except UnsafeUploadPathError: + logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name) + continue except Exception: logger.exception("[Manager] failed to write inbound file: %s", dest) continue diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py index 604a6e154..a4267f728 100644 --- a/backend/app/gateway/routers/uploads.py +++ b/backend/app/gateway/routers/uploads.py @@ -5,7 +5,7 @@ import os import stat from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile -from pydantic import BaseModel +from pydantic import BaseModel, Field from app.gateway.authz import require_permission from app.gateway.deps import get_config @@ -15,12 +15,14 @@ from deerflow.runtime.user_context import get_effective_user_id from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider from deerflow.uploads.manager import ( PathTraversalError, + 
UnsafeUploadPathError, delete_file_safe, enrich_file_listing, ensure_uploads_dir, get_uploads_dir, list_files_in_dir, normalize_filename, + open_upload_file_no_symlink, upload_artifact_url, upload_virtual_path, ) @@ -42,6 +44,7 @@ class UploadResponse(BaseModel): success: bool files: list[dict[str, str]] message: str + skipped_files: list[str] = Field(default_factory=list) class UploadLimits(BaseModel): @@ -116,17 +119,18 @@ def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None: logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True) -async def _write_upload_file_streaming( +async def _write_upload_file_with_limits( file: UploadFile, - file_path: os.PathLike[str] | str, *, + uploads_dir: os.PathLike[str] | str, display_filename: str, max_single_file_size: int, max_total_size: int, total_size: int, -) -> tuple[int, int]: +) -> tuple[os.PathLike[str] | str, int, int]: file_size = 0 - with open(file_path, "wb") as output: + file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename) + try: while chunk := await file.read(UPLOAD_CHUNK_SIZE): file_size += len(chunk) total_size += len(chunk) @@ -134,8 +138,17 @@ async def _write_upload_file_streaming( raise HTTPException(status_code=413, detail=f"File too large: {display_filename}") if total_size > max_total_size: raise HTTPException(status_code=413, detail="Total upload size too large") - output.write(chunk) - return file_size, total_size + fh.write(chunk) + except Exception: + fh.close() + try: + os.unlink(file_path) + except FileNotFoundError: + pass + raise + else: + fh.close() + return file_path, file_size, total_size def _auto_convert_documents_enabled(app_config: AppConfig) -> bool: @@ -177,6 +190,7 @@ async def upload_files( uploaded_files = [] written_paths = [] sandbox_sync_targets = [] + skipped_files = [] total_size = 0 sandbox_provider = get_sandbox_provider() @@ -200,16 +214,15 @@ async def upload_files( continue try: - file_path = uploads_dir / safe_filename - written_paths.append(file_path) - file_size, total_size = await _write_upload_file_streaming( + file_path, file_size, total_size = await _write_upload_file_with_limits( file, - file_path, + uploads_dir=uploads_dir, display_filename=safe_filename, max_single_file_size=limits.max_file_size, max_total_size=limits.max_total_size, total_size=total_size, ) + written_paths.append(file_path) virtual_path = upload_virtual_path(safe_filename) @@ -246,6 +259,10 @@ async def upload_files( except HTTPException as e: _cleanup_uploaded_paths(written_paths) raise e + except UnsafeUploadPathError as e: + logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e) + skipped_files.append(safe_filename) + continue except Exception as e: logger.error(f"Failed to upload {file.filename}: {e}") _cleanup_uploaded_paths(written_paths) @@ -256,10 +273,15 @@ async def upload_files( _make_file_sandbox_writable(file_path) sandbox.update_file(virtual_path, file_path.read_bytes()) + message = f"Successfully uploaded {len(uploaded_files)} file(s)" + if skipped_files: + message += f"; skipped {len(skipped_files)} unsafe file(s)" + return UploadResponse( - success=True, + success=not skipped_files, files=uploaded_files, - message=f"Successfully uploaded {len(uploaded_files)} file(s)", + message=message, + skipped_files=skipped_files, ) diff --git a/backend/packages/harness/deerflow/uploads/manager.py b/backend/packages/harness/deerflow/uploads/manager.py index c36151b38..1a1b63f09 100644 --- 
a/backend/packages/harness/deerflow/uploads/manager.py +++ b/backend/packages/harness/deerflow/uploads/manager.py @@ -4,8 +4,10 @@ Pure business logic — no FastAPI/HTTP dependencies. Both Gateway and Client delegate to these functions. """ +import errno import os import re +import stat from pathlib import Path from urllib.parse import quote @@ -17,6 +19,10 @@ class PathTraversalError(ValueError): """Raised when a path escapes its allowed base directory.""" +class UnsafeUploadPathError(ValueError): + """Raised when an upload destination is not a safe regular file path.""" + + # thread_id must be alphanumeric, hyphens, underscores, or dots only. _SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$") @@ -109,6 +115,64 @@ def validate_path_traversal(path: Path, base: Path) -> None: raise PathTraversalError("Path traversal detected") from None +def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]: + """Open an upload destination for safe streaming writes. + + Upload directories may be mounted into local sandboxes. A sandbox process can + therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes`` + follows that link and can overwrite files outside the uploads directory with + gateway privileges. This helper rejects symlink destinations and uses + ``O_NOFOLLOW`` where available so the final path component cannot be raced into + a symlink between validation and open. + """ + safe_name = normalize_filename(filename) + dest = base_dir / safe_name + + try: + st = os.lstat(dest) + except FileNotFoundError: + st = None + + if st is not None and not stat.S_ISREG(st.st_mode): + raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}") + + validate_path_traversal(dest, base_dir) + + if not hasattr(os, "O_NOFOLLOW"): + raise UnsafeUploadPathError("Upload writes require O_NOFOLLOW support") + + flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW + if hasattr(os, "O_NONBLOCK"): + flags |= os.O_NONBLOCK + + try: + fd = os.open(dest, flags, 0o600) + except OSError as exc: + if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}: + raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc + raise + + try: + opened_stat = os.fstat(fd) + if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1: + raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}") + os.ftruncate(fd, 0) + fh = os.fdopen(fd, "wb") + fd = -1 + finally: + if fd >= 0: + os.close(fd) + return dest, fh + + +def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path: + """Write upload bytes without following a pre-existing destination symlink.""" + dest, fh = open_upload_file_no_symlink(base_dir, filename) + with fh: + fh.write(data) + return dest + + def list_files_in_dir(directory: Path) -> dict: """List files (not directories) in *directory*. 
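
A condensed, self-contained sketch of the lstat / O_NOFOLLOW / fstat pattern
that ``open_upload_file_no_symlink`` builds on (illustrative only: the helper
name ``_open_regular_nofollow`` and the plain ``ValueError`` are stand-ins, and
it assumes a POSIX platform where ``os.O_NOFOLLOW`` exists):

    import errno
    import os
    import stat
    from pathlib import Path

    def _open_regular_nofollow(dest: Path) -> int:
        """Open *dest* for writing only if it is a plain, un-aliased file."""
        # 1. Pre-check without following links: anything that already exists
        #    and is not a regular file (symlink, FIFO, directory) is rejected.
        try:
            if not stat.S_ISREG(os.lstat(dest).st_mode):
                raise ValueError(f"not a regular file: {dest}")
        except FileNotFoundError:
            pass  # creating a brand-new file is fine
        # 2. Open with O_NOFOLLOW so a symlink raced in *after* the lstat
        #    makes the open fail with ELOOP instead of following the link.
        flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW | getattr(os, "O_NONBLOCK", 0)
        try:
            fd = os.open(dest, flags, 0o600)
        except OSError as exc:
            if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENXIO, errno.EAGAIN}:
                raise ValueError(f"unsafe destination: {dest}") from exc
            raise
        # 3. Re-validate the descriptor actually opened: O_NOFOLLOW does not
        #    reject hard links, so require a regular file with exactly one link.
        st = os.fstat(fd)
        if not stat.S_ISREG(st.st_mode) or st.st_nlink != 1:
            os.close(fd)
            raise ValueError(f"aliased or special destination: {dest}")
        os.ftruncate(fd, 0)  # truncate only after every check has passed
        return fd

The fstat re-check is the piece that handles hard links: ``O_NOFOLLOW`` only
guards the symlink case, and a hard link to a file outside the uploads
directory opens successfully, so the ``st_nlink != 1`` test is what rejects it
(see ``test_upload_files_rejects_hardlinked_destination_without_truncating``
below).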
diff --git a/backend/tests/test_channel_file_attachments.py b/backend/tests/test_channel_file_attachments.py index 7273b1c82..aa2e9b004 100644 --- a/backend/tests/test_channel_file_attachments.py +++ b/backend/tests/test_channel_file_attachments.py @@ -3,11 +3,12 @@ from __future__ import annotations import asyncio +import os from pathlib import Path from unittest.mock import MagicMock, patch from app.channels.base import Channel -from app.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment +from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage, ResolvedAttachment def _run(coro): @@ -248,6 +249,109 @@ class TestResolveAttachments: assert result[0].filename == "data.csv" +# --------------------------------------------------------------------------- +# Inbound file ingestion tests +# --------------------------------------------------------------------------- + + +class TestInboundFileIngestion: + def test_rejects_preexisting_symlink_destination(self, tmp_path): + from app.channels import manager + + uploads_dir = tmp_path / "uploads" + uploads_dir.mkdir() + outside_file = tmp_path / "outside-created.txt" + (uploads_dir / "victim.txt").symlink_to(outside_file) + + msg = InboundMessage( + channel_name="test-channel", + chat_id="chat-1", + user_id="user-1", + text="see attachment", + files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}], + ) + + async def fake_reader(file_info, client): + return b"attacker data" + + with ( + patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir), + patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False), + ): + result = _run(manager._ingest_inbound_files("thread-1", msg)) + + assert result == [] + assert not outside_file.exists() + assert (uploads_dir / "victim.txt").is_symlink() + + def test_rejects_dangling_symlink_destination(self, tmp_path): + from app.channels import manager + + uploads_dir = tmp_path / "uploads" + uploads_dir.mkdir() + missing_target = tmp_path / "missing-created.txt" + (uploads_dir / "victim.txt").symlink_to(missing_target) + + msg = InboundMessage( + channel_name="test-channel", + chat_id="chat-1", + user_id="user-1", + text="see attachment", + files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}], + ) + + async def fake_reader(file_info, client): + return b"attacker data" + + with ( + patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir), + patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False), + ): + result = _run(manager._ingest_inbound_files("thread-1", msg)) + + assert result == [] + assert not missing_target.exists() + assert (uploads_dir / "victim.txt").is_symlink() + + def test_hardlinked_existing_file_is_not_overwritten(self, tmp_path): + from app.channels import manager + + uploads_dir = tmp_path / "uploads" + uploads_dir.mkdir() + outside_file = tmp_path / "outside-created.txt" + outside_file.write_text("protected", encoding="utf-8") + os.link(outside_file, uploads_dir / "victim.txt") + + msg = InboundMessage( + channel_name="test-channel", + chat_id="chat-1", + user_id="user-1", + text="see attachment", + files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}], + ) + + async def fake_reader(file_info, client): + return b"new attachment data" + + with ( + patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir), + patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": 
fake_reader}, clear=False), + ): + result = _run(manager._ingest_inbound_files("thread-1", msg)) + + assert result == [ + { + "filename": "victim_1.txt", + "size": len(b"new attachment data"), + "path": "/mnt/user-data/uploads/victim_1.txt", + "is_image": False, + } + ] + assert outside_file.read_text(encoding="utf-8") == "protected" + assert (uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected" + assert (uploads_dir / "victim_1.txt").read_bytes() == b"new attachment data" + + # --------------------------------------------------------------------------- # Channel base class _on_outbound with attachments # --------------------------------------------------------------------------- diff --git a/backend/tests/test_uploads_manager.py b/backend/tests/test_uploads_manager.py index 64964c0b0..2cf1ae7fb 100644 --- a/backend/tests/test_uploads_manager.py +++ b/backend/tests/test_uploads_manager.py @@ -1,14 +1,20 @@ """Tests for deerflow.uploads.manager — shared upload management logic.""" +import errno +import os +from unittest.mock import patch + import pytest from deerflow.uploads.manager import ( PathTraversalError, + UnsafeUploadPathError, claim_unique_filename, delete_file_safe, list_files_in_dir, normalize_filename, validate_path_traversal, + write_upload_file_no_symlink, ) # --------------------------------------------------------------------------- @@ -97,6 +103,54 @@ class TestValidatePathTraversal: validate_path_traversal(link, tmp_path) +# --------------------------------------------------------------------------- +# write_upload_file_no_symlink +# --------------------------------------------------------------------------- + + +class TestWriteUploadFileNoSymlink: + def test_writes_new_file(self, tmp_path): + dest = write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello") + + assert dest == tmp_path / "notes.txt" + assert dest.read_bytes() == b"hello" + + def test_overwrites_existing_regular_file_with_single_link(self, tmp_path): + dest = tmp_path / "notes.txt" + dest.write_bytes(b"old contents") + assert os.stat(dest).st_nlink == 1 + + result = write_upload_file_no_symlink(tmp_path, "notes.txt", b"new contents") + + assert result == dest + assert dest.read_bytes() == b"new contents" + assert os.stat(dest).st_nlink == 1 + + def test_fails_closed_without_no_follow_support(self, tmp_path, monkeypatch): + monkeypatch.delattr(os, "O_NOFOLLOW", raising=False) + + with pytest.raises(UnsafeUploadPathError, match="O_NOFOLLOW"): + write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello") + + assert not (tmp_path / "notes.txt").exists() + + def test_open_uses_nonblocking_flag_when_available(self, tmp_path): + with patch("deerflow.uploads.manager.os.open", side_effect=OSError(errno.ENXIO, "no reader")) as open_mock: + with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"): + write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello") + + flags = open_mock.call_args.args[1] + assert flags & os.O_NONBLOCK + + @pytest.mark.parametrize("open_errno", [errno.ENXIO, errno.EAGAIN]) + def test_nonblocking_special_file_open_errors_are_unsafe(self, tmp_path, open_errno): + with patch("deerflow.uploads.manager.os.open", side_effect=OSError(open_errno, "would block")): + with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"): + write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello") + + assert not (tmp_path / "pipe.txt").exists() + + # --------------------------------------------------------------------------- # list_files_in_dir # 
--------------------------------------------------------------------------- diff --git a/backend/tests/test_uploads_router.py b/backend/tests/test_uploads_router.py index a2538ec40..4a778345f 100644 --- a/backend/tests/test_uploads_router.py +++ b/backend/tests/test_uploads_router.py @@ -1,4 +1,5 @@ import asyncio +import os import stat from io import BytesIO from pathlib import Path @@ -428,6 +429,105 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path): assert [f.name for f in thread_uploads_dir.iterdir()] == ["passwd"] +def test_upload_files_rejects_preexisting_symlink_destination(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + outside_file = tmp_path / "outside.txt" + outside_file.write_text("protected", encoding="utf-8") + (thread_uploads_dir / "victim.txt").symlink_to(outside_file) + + provider = MagicMock() + provider.uses_thread_data_mounts = True + + with ( + patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload")) + result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + + assert result.success is False + assert result.files == [] + assert result.skipped_files == ["victim.txt"] + assert "skipped 1 unsafe file" in result.message + assert outside_file.read_text(encoding="utf-8") == "protected" + assert (thread_uploads_dir / "victim.txt").is_symlink() + + +def test_upload_files_rejects_dangling_symlink_destination(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + missing_target = tmp_path / "missing-target.txt" + (thread_uploads_dir / "victim.txt").symlink_to(missing_target) + + provider = MagicMock() + provider.uses_thread_data_mounts = True + + with ( + patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload")) + result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + + assert result.success is False + assert result.files == [] + assert result.skipped_files == ["victim.txt"] + assert not missing_target.exists() + assert (thread_uploads_dir / "victim.txt").is_symlink() + + +def test_upload_files_rejects_hardlinked_destination_without_truncating(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + outside_file = tmp_path / "outside.txt" + outside_file.write_text("protected", encoding="utf-8") + os.link(outside_file, thread_uploads_dir / "victim.txt") + + provider = MagicMock() + provider.uses_thread_data_mounts = True + + with ( + patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload")) + result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + + assert result.success is False + assert result.files == [] + assert result.skipped_files == ["victim.txt"] + assert outside_file.read_text(encoding="utf-8") == "protected" + assert 
(thread_uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected" + + +def test_upload_files_overwrites_existing_regular_file(tmp_path): + thread_uploads_dir = tmp_path / "uploads" + thread_uploads_dir.mkdir(parents=True) + existing_file = thread_uploads_dir / "notes.txt" + existing_file.write_bytes(b"old upload") + assert existing_file.stat().st_nlink == 1 + + provider = MagicMock() + provider.uses_thread_data_mounts = True + + with ( + patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), + patch.object(uploads, "get_sandbox_provider", return_value=provider), + ): + file = UploadFile(filename="notes.txt", file=BytesIO(b"new upload")) + result = asyncio.run(uploads.upload_files("thread-local", files=[file])) + + assert result.success is True + assert [file_info["filename"] for file_info in result.files] == ["notes.txt"] + assert existing_file.read_bytes() == b"new upload" + assert existing_file.stat().st_nlink == 1 + + def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path): thread_uploads_dir = tmp_path / "uploads" thread_uploads_dir.mkdir(parents=True) From 44ab21fc44981eb2de9e4c0bccd2039c5195715c Mon Sep 17 00:00:00 2001 From: wanxsb Date: Sat, 2 May 2026 16:22:35 +0800 Subject: [PATCH 11/11] feat(community): add Serper web search provider (#2630) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(community): add Serper web search provider Add a new community search provider backed by the Serper Google Search API (https://serper.dev). Serper returns real-time Google results via a simple JSON API and requires only an API key — no extra Python package. Changes: - backend/packages/harness/deerflow/community/serper/__init__.py - backend/packages/harness/deerflow/community/serper/tools.py Implements web_search_tool using httpx (already a project dependency). API key is read from config.yaml `api_key` field or SERPER_API_KEY env var. Follows the same interface / output shape as the existing ddg_search provider. Exposes max_results parameter (default 5) with config override logic. - backend/tests/test_serper_tools.py Unit tests covering API key resolution, config overrides, HTTP errors, empty results, and parameter passing. 
- config.example.yaml: add commented-out Serper example alongside other providers - .env.example: add SERPER_API_KEY placeholder Co-Authored-By: Claude Sonnet 4.6 * Fix the lint error * Fix the lint error --------- Co-authored-by: Claude Sonnet 4.6 Co-authored-by: Willem Jiang --- .env.example | 3 + .../deerflow/community/serper/__init__.py | 3 + .../deerflow/community/serper/tools.py | 95 ++++++ backend/tests/test_serper_tools.py | 308 ++++++++++++++++++ config.example.yaml | 10 + 5 files changed, 419 insertions(+) create mode 100644 backend/packages/harness/deerflow/community/serper/__init__.py create mode 100644 backend/packages/harness/deerflow/community/serper/tools.py create mode 100644 backend/tests/test_serper_tools.py diff --git a/.env.example b/.env.example index f443818b3..41d87a8c7 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,6 @@ +# Serper API Key (Google Search) - https://serper.dev +SERPER_API_KEY=your-serper-api-key + # TAVILY API Key TAVILY_API_KEY=your-tavily-api-key diff --git a/backend/packages/harness/deerflow/community/serper/__init__.py b/backend/packages/harness/deerflow/community/serper/__init__.py new file mode 100644 index 000000000..876167859 --- /dev/null +++ b/backend/packages/harness/deerflow/community/serper/__init__.py @@ -0,0 +1,3 @@ +from .tools import web_search_tool + +__all__ = ["web_search_tool"] diff --git a/backend/packages/harness/deerflow/community/serper/tools.py b/backend/packages/harness/deerflow/community/serper/tools.py new file mode 100644 index 000000000..1cad11fb8 --- /dev/null +++ b/backend/packages/harness/deerflow/community/serper/tools.py @@ -0,0 +1,95 @@ +""" +Web Search Tool - Search the web using Serper (Google Search API). + +Serper provides real-time Google Search results via a JSON API. +An API key is required. Sign up at https://serper.dev to get one. +""" + +import json +import logging +import os + +import httpx +from langchain.tools import tool + +from deerflow.config import get_app_config + +logger = logging.getLogger(__name__) + +_SERPER_ENDPOINT = "https://google.serper.dev/search" +_api_key_warned = False + + +def _get_api_key() -> str | None: + config = get_app_config().get_tool_config("web_search") + if config is not None: + api_key = config.model_extra.get("api_key") + if isinstance(api_key, str) and api_key.strip(): + return api_key + return os.getenv("SERPER_API_KEY") + + +@tool("web_search", parse_docstring=True) +def web_search_tool(query: str, max_results: int = 5) -> str: + """Search the web for information using Google Search via Serper. + + Args: + query: Search keywords describing what you want to find. Be specific for better results. + max_results: Maximum number of search results to return. Default is 5. + """ + global _api_key_warned + + config = get_app_config().get_tool_config("web_search") + if config is not None and "max_results" in config.model_extra: + max_results = config.model_extra.get("max_results", max_results) + + api_key = _get_api_key() + if not api_key: + if not _api_key_warned: + _api_key_warned = True + logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. 
Sign up at https://serper.dev") + return json.dumps( + {"error": "SERPER_API_KEY is not configured", "query": query}, + ensure_ascii=False, + ) + + headers = { + "X-API-KEY": api_key, + "Content-Type": "application/json", + } + payload = {"q": query, "num": max_results} + + try: + with httpx.Client(timeout=30) as client: + response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload) + response.raise_for_status() + data = response.json() + except httpx.HTTPStatusError as e: + logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}") + return json.dumps( + {"error": f"Serper API error: HTTP {e.response.status_code}", "query": query}, + ensure_ascii=False, + ) + except Exception as e: + logger.error(f"Serper search failed: {type(e).__name__}: {e}") + return json.dumps({"error": str(e), "query": query}, ensure_ascii=False) + + organic = data.get("organic", []) + if not organic: + return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False) + + normalized_results = [ + { + "title": r.get("title", ""), + "url": r.get("link", ""), + "content": r.get("snippet", ""), + } + for r in organic[:max_results] + ] + + output = { + "query": query, + "total_results": len(normalized_results), + "results": normalized_results, + } + return json.dumps(output, indent=2, ensure_ascii=False) diff --git a/backend/tests/test_serper_tools.py b/backend/tests/test_serper_tools.py new file mode 100644 index 000000000..2e53b0351 --- /dev/null +++ b/backend/tests/test_serper_tools.py @@ -0,0 +1,308 @@ +"""Unit tests for the Serper community web search tool.""" + +import json +from unittest.mock import MagicMock, patch + +import httpx +import pytest + + +@pytest.fixture(autouse=True) +def reset_api_key_warned(): + """Reset the module-level warning flag before each test.""" + import deerflow.community.serper.tools as serper_mod + + serper_mod._api_key_warned = False + yield + serper_mod._api_key_warned = False + + +@pytest.fixture +def mock_config_with_key(): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + tool_config = MagicMock() + tool_config.model_extra = {"api_key": "test-serper-key", "max_results": 5} + mock.return_value.get_tool_config.return_value = tool_config + yield mock + + +@pytest.fixture +def mock_config_no_key(): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + tool_config = MagicMock() + tool_config.model_extra = {} + mock.return_value.get_tool_config.return_value = tool_config + yield mock + + +def _make_serper_response(organic: list) -> MagicMock: + mock_resp = MagicMock() + mock_resp.json.return_value = {"organic": organic} + mock_resp.raise_for_status = MagicMock() + return mock_resp + + +class TestGetApiKey: + def test_returns_config_key_when_present(self): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + tool_config = MagicMock() + tool_config.model_extra = {"api_key": "from-config"} + mock.return_value.get_tool_config.return_value = tool_config + + from deerflow.community.serper.tools import _get_api_key + + assert _get_api_key() == "from-config" + + def test_falls_back_to_env_when_config_key_empty(self): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + tool_config = MagicMock() + tool_config.model_extra = {"api_key": ""} + mock.return_value.get_tool_config.return_value = tool_config + with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}): + from deerflow.community.serper.tools import _get_api_key + + assert _get_api_key() == 
"env-key" + + def test_falls_back_to_env_when_config_key_whitespace(self): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + tool_config = MagicMock() + tool_config.model_extra = {"api_key": " "} + mock.return_value.get_tool_config.return_value = tool_config + with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}): + from deerflow.community.serper.tools import _get_api_key + + assert _get_api_key() == "env-key" + + def test_falls_back_to_env_when_config_key_null(self): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + tool_config = MagicMock() + tool_config.model_extra = {"api_key": None} + mock.return_value.get_tool_config.return_value = tool_config + with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}): + from deerflow.community.serper.tools import _get_api_key + + assert _get_api_key() == "env-key" + + def test_falls_back_to_env_when_no_config(self): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + mock.return_value.get_tool_config.return_value = None + with patch.dict("os.environ", {"SERPER_API_KEY": "env-only"}): + from deerflow.community.serper.tools import _get_api_key + + assert _get_api_key() == "env-only" + + def test_returns_none_when_no_key_anywhere(self): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + mock.return_value.get_tool_config.return_value = None + with patch.dict("os.environ", {}, clear=True): + import os + + os.environ.pop("SERPER_API_KEY", None) + from deerflow.community.serper.tools import _get_api_key + + assert _get_api_key() is None + + +class TestWebSearchTool: + def test_basic_search_returns_normalized_results(self, mock_config_with_key): + organic = [ + {"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"}, + {"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"}, + ] + mock_resp = _make_serper_response(organic) + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "python tutorial"}) + parsed = json.loads(result) + + assert parsed["query"] == "python tutorial" + assert parsed["total_results"] == 2 + assert parsed["results"][0]["title"] == "Result 1" + assert parsed["results"][0]["url"] == "https://example.com/1" + assert parsed["results"][0]["content"] == "Snippet 1" + + def test_respects_max_results_from_config(self, mock_config_with_key): + mock_config_with_key.return_value.get_tool_config.return_value.model_extra = { + "api_key": "test-key", + "max_results": 3, + } + organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)] + mock_resp = _make_serper_response(organic) + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "test"}) + parsed = json.loads(result) + + assert parsed["total_results"] == 3 + assert len(parsed["results"]) == 3 + + def test_max_results_parameter_accepted(self, mock_config_no_key): + """Tool accepts max_results as a call parameter when config does not override it.""" + organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)] + mock_resp = 
_make_serper_response(organic) + + with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}): + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "test", "max_results": 2}) + parsed = json.loads(result) + + assert parsed["total_results"] == 2 + + def test_config_max_results_overrides_parameter(self): + """Config max_results overrides the parameter passed at call time, matching ddg_search behaviour.""" + with patch("deerflow.community.serper.tools.get_app_config") as mock: + tool_config = MagicMock() + tool_config.model_extra = {"api_key": "test-key", "max_results": 3} + mock.return_value.get_tool_config.return_value = tool_config + + organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)] + mock_resp = _make_serper_response(organic) + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "test", "max_results": 8}) + parsed = json.loads(result) + + assert parsed["total_results"] == 3 + + def test_empty_organic_returns_error_json(self, mock_config_with_key): + """Empty organic list returns structured error, matching ddg_search convention.""" + mock_resp = _make_serper_response([]) + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "no results"}) + parsed = json.loads(result) + + assert "error" in parsed + assert parsed["error"] == "No results found" + assert parsed["query"] == "no results" + + def test_missing_api_key_returns_error_json(self, mock_config_no_key): + with patch.dict("os.environ", {}, clear=True): + import os + + os.environ.pop("SERPER_API_KEY", None) + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "test"}) + parsed = json.loads(result) + + assert "error" in parsed + assert "SERPER_API_KEY" in parsed["error"] + + def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog): + import logging + + with patch.dict("os.environ", {}, clear=True): + import os + + os.environ.pop("SERPER_API_KEY", None) + + from deerflow.community.serper.tools import web_search_tool + + with caplog.at_level(logging.WARNING, logger="deerflow.community.serper.tools"): + web_search_tool.invoke({"query": "q1"}) + web_search_tool.invoke({"query": "q2"}) + + warnings = [r for r in caplog.records if r.levelno == logging.WARNING] + assert len(warnings) == 1 + + def test_http_error_returns_structured_error(self, mock_config_with_key): + mock_error_response = MagicMock() + mock_error_response.status_code = 403 + mock_error_response.text = "Forbidden" + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.side_effect = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response) + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "test"}) + parsed = json.loads(result) + + 
assert "error" in parsed + assert "403" in parsed["error"] + + def test_network_exception_returns_error_json(self, mock_config_with_key): + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.side_effect = Exception("timeout") + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "test"}) + parsed = json.loads(result) + + assert "error" in parsed + + def test_sends_correct_headers_and_payload(self, mock_config_with_key): + organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}] + mock_resp = _make_serper_response(organic) + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_post = mock_client_cls.return_value.__enter__.return_value.post + mock_post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + web_search_tool.invoke({"query": "hello world"}) + + call_kwargs = mock_post.call_args + headers = call_kwargs.kwargs["headers"] + payload = call_kwargs.kwargs["json"] + + assert headers["X-API-KEY"] == "test-serper-key" + assert payload["q"] == "hello world" + assert payload["num"] == 5 + + def test_uses_env_key_when_config_absent(self): + with patch("deerflow.community.serper.tools.get_app_config") as mock: + mock.return_value.get_tool_config.return_value = None + with patch.dict("os.environ", {"SERPER_API_KEY": "env-only-key"}): + organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}] + mock_resp = _make_serper_response(organic) + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_post = mock_client_cls.return_value.__enter__.return_value.post + mock_post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + web_search_tool.invoke({"query": "env key test"}) + headers = mock_post.call_args.kwargs["headers"] + + assert headers["X-API-KEY"] == "env-only-key" + + def test_partial_fields_in_organic_result(self, mock_config_with_key): + """Missing title/link/snippet should default to empty string.""" + organic = [{}] + mock_resp = _make_serper_response(organic) + + with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls: + mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp + + from deerflow.community.serper.tools import web_search_tool + + result = web_search_tool.invoke({"query": "test"}) + parsed = json.loads(result) + + assert parsed["results"][0] == {"title": "", "url": "", "content": ""} diff --git a/config.example.yaml b/config.example.yaml index b16b4a6bb..7e282e46e 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -373,6 +373,16 @@ tools: use: deerflow.community.ddg_search.tools:web_search_tool max_results: 5 + # Web search tool (uses Serper - Google Search API, requires SERPER_API_KEY) + # Serper provides real-time Google Search results. Sign up at https://serper.dev + # Note: set SERPER_API_KEY in your environment before starting the app, or + # uncomment and fill in api_key below (the $VAR syntax is resolved at startup). + # - name: web_search + # group: web + # use: deerflow.community.serper.tools:web_search_tool + # max_results: 5 + # # api_key: $SERPER_API_KEY # Optional if SERPER_API_KEY env var is set + # Web search tool (requires Tavily API key) # - name: web_search # group: web