Mirror of https://github.com/bytedance/deer-flow.git, synced 2026-05-11 11:13:51 +00:00

Merge branch 'main' into fix-2425

Commit f77dcf9c23
@@ -1,3 +1,6 @@
+# Serper API Key (Google Search) - https://serper.dev
+SERPER_API_KEY=your-serper-api-key
+
 # TAVILY API Key
 TAVILY_API_KEY=your-tavily-api-key
 
@@ -251,7 +251,7 @@ See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
 
 If you prefer running services locally:
 
-Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root (can be overridden via `DEER_FLOW_CONFIG_PATH`). Run `make doctor` to verify your setup before starting.
+Prerequisite: complete the "Configuration" steps above first (`make setup`). `make dev` requires a valid `config.yaml` in the project root. Set `DEER_FLOW_PROJECT_ROOT` to define that root explicitly, or `DEER_FLOW_CONFIG_PATH` to point at a specific config file. Runtime state defaults to `.deer-flow` under the project root and can be moved with `DEER_FLOW_HOME`; skills default to `skills/` under the project root and can be moved with `DEER_FLOW_SKILLS_PATH`. Run `make doctor` to verify your setup before starting.
 On Windows, run the local development flow from Git Bash. Native `cmd.exe` and PowerShell shells are not supported for the bash-based service scripts, and WSL is not guaranteed because some scripts rely on Git for Windows utilities such as `cygpath`.
 
 1. **Check prerequisites**:
@@ -194,7 +194,7 @@ make down  # stop and remove containers
 
 If you prefer to start each service locally:
 
-Prerequisite: complete the "Configuration" steps above first (`make config` plus model API key setup). `make dev` requires a valid config file, read from `config.yaml` in the project root by default; this can be overridden via `DEER_FLOW_CONFIG_PATH`.
+Prerequisite: complete the "Configuration" steps above first (`make config` plus model API key setup). `make dev` requires a valid config file, read from `config.yaml` in the project root by default. Use `DEER_FLOW_PROJECT_ROOT` to set the project root explicitly, or `DEER_FLOW_CONFIG_PATH` to point at a specific config file. Runtime state is written to `.deer-flow` under the project root by default and can be overridden with `DEER_FLOW_HOME`; skills are read from `skills/` under the project root by default and can be overridden with `DEER_FLOW_SKILLS_PATH`.
 On Windows, run the local development flow from Git Bash. The bash-based service scripts cannot run directly in native `cmd.exe` or PowerShell, and WSL is not guaranteed to work either, because some scripts depend on Git for Windows utilities such as `cygpath`.
 
 1. **Check prerequisites**:
@@ -420,7 +420,13 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
     if not msg.files:
         return []
 
-    from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
+    from deerflow.uploads.manager import (
+        UnsafeUploadPathError,
+        claim_unique_filename,
+        ensure_uploads_dir,
+        normalize_filename,
+        write_upload_file_no_symlink,
+    )
 
     uploads_dir = ensure_uploads_dir(thread_id)
     seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
@@ -471,7 +477,10 @@ async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dic
 
-        dest = uploads_dir / safe_name
         try:
-            dest.write_bytes(data)
+            dest = write_upload_file_no_symlink(uploads_dir, safe_name, data)
+        except UnsafeUploadPathError:
+            logger.warning("[Manager] skipping inbound file with unsafe destination: %s", safe_name)
+            continue
         except Exception:
             logger.exception("[Manager] failed to write inbound file: %s", dest)
             continue
 
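The body of `write_upload_file_no_symlink` is not part of this diff; the following is a minimal sketch of the guard it presumably provides — every name and check beyond what the hunk above shows is an assumption, not the repository's actual implementation:

```python
from pathlib import Path


class UnsafeUploadPathError(ValueError):
    """Raised when a write would escape the uploads dir or follow a symlink."""


def write_upload_file_no_symlink(uploads_dir: Path, name: str, data: bytes) -> Path:
    dest = uploads_dir / name
    # Reject names that resolve outside the uploads directory (e.g. "../x").
    if uploads_dir.resolve() not in dest.resolve().parents:
        raise UnsafeUploadPathError(f"destination escapes uploads dir: {name}")
    # Reject pre-existing symlinks so a planted link cannot redirect the write.
    if dest.is_symlink():
        raise UnsafeUploadPathError(f"destination is a symlink: {name}")
    dest.write_bytes(data)
    return dest
```

A hardened implementation would typically open the file with `os.open(..., O_NOFOLLOW | O_CREAT)` to close the race between the symlink check and the write; the sketch keeps only the shape of the checks.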
@@ -13,11 +13,11 @@ matching the LangGraph Platform wire format expected by the
 from __future__ import annotations
 
 import logging
-import time
 import uuid
 from typing import Any
 
 from fastapi import APIRouter, HTTPException, Request
+from langgraph.checkpoint.base import empty_checkpoint
 from pydantic import BaseModel, Field, field_validator
 
 from app.gateway.authz import require_permission
@@ -26,6 +26,7 @@ from app.gateway.utils import sanitize_log_param
 from deerflow.config.paths import Paths, get_paths
 from deerflow.runtime import serialize_channel_values
 from deerflow.runtime.user_context import get_effective_user_id
+from deerflow.utils.time import coerce_iso, now_iso
 
 logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/api/threads", tags=["threads"])
@@ -233,7 +234,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
     checkpointer = get_checkpointer(request)
     thread_store = get_thread_store(request)
     thread_id = body.thread_id or str(uuid.uuid4())
-    now = time.time()
+    now = now_iso()
     # ``body.metadata`` is already stripped of server-reserved keys by
     # ``ThreadCreateRequest._strip_reserved`` — see the model definition.
 
@@ -243,8 +244,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
         return ThreadResponse(
             thread_id=thread_id,
             status=existing_record.get("status", "idle"),
-            created_at=str(existing_record.get("created_at", "")),
-            updated_at=str(existing_record.get("updated_at", "")),
+            created_at=coerce_iso(existing_record.get("created_at", "")),
+            updated_at=coerce_iso(existing_record.get("updated_at", "")),
             metadata=existing_record.get("metadata", {}),
         )
 
@@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
     # Write an empty checkpoint so state endpoints work immediately
     config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
     try:
-        from langgraph.checkpoint.base import empty_checkpoint
-
         ckpt_metadata = {
             "step": -1,
             "source": "input",
@@ -281,8 +280,8 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
     return ThreadResponse(
         thread_id=thread_id,
         status="idle",
-        created_at=str(now),
-        updated_at=str(now),
+        created_at=now,
+        updated_at=now,
         metadata=body.metadata,
     )
 
@@ -307,8 +306,11 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
         ThreadResponse(
             thread_id=r["thread_id"],
             status=r.get("status", "idle"),
-            created_at=r.get("created_at", ""),
-            updated_at=r.get("updated_at", ""),
+            # ``coerce_iso`` heals legacy unix-second values that
+            # ``MemoryThreadMetaStore`` historically wrote with ``time.time()``;
+            # SQL-backed rows already arrive as ISO strings and pass through.
+            created_at=coerce_iso(r.get("created_at", "")),
+            updated_at=coerce_iso(r.get("updated_at", "")),
             metadata=r.get("metadata", {}),
             values={"title": r["display_name"]} if r.get("display_name") else {},
             interrupts={},
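`coerce_iso` itself is not shown in this diff; a minimal sketch consistent with the comment above, assuming it accepts either epoch seconds or an already-ISO string (the exact signature is an assumption):

```python
from datetime import datetime, timezone


def coerce_iso(value: object) -> str:
    # SQL-backed rows already arrive as ISO-8601 strings and pass through.
    if isinstance(value, str):
        return value
    # Legacy unix-second values written by time.time() become UTC ISO strings.
    if isinstance(value, (int, float)):
        return datetime.fromtimestamp(value, tz=timezone.utc).isoformat()
    return ""
```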
@@ -340,8 +342,8 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
     return ThreadResponse(
         thread_id=thread_id,
         status=record.get("status", "idle"),
-        created_at=str(record.get("created_at", "")),
-        updated_at=str(record.get("updated_at", "")),
+        created_at=coerce_iso(record.get("created_at", "")),
+        updated_at=coerce_iso(record.get("updated_at", "")),
         metadata=record.get("metadata", {}),
     )
 
@@ -381,8 +383,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
         record = {
             "thread_id": thread_id,
             "status": "idle",
-            "created_at": ckpt_meta.get("created_at", ""),
-            "updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
+            "created_at": coerce_iso(ckpt_meta.get("created_at", "")),
+            "updated_at": coerce_iso(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
             "metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
         }
 
@@ -396,8 +398,8 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
     return ThreadResponse(
         thread_id=thread_id,
         status=status,
-        created_at=str(record.get("created_at", "")),
-        updated_at=str(record.get("updated_at", "")),
+        created_at=coerce_iso(record.get("created_at", "")),
+        updated_at=coerce_iso(record.get("updated_at", "")),
         metadata=record.get("metadata", {}),
         values=serialize_channel_values(channel_values),
     )
@@ -448,10 +450,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
         values=values,
         next=next_tasks,
         metadata=metadata,
-        checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
+        checkpoint={"id": checkpoint_id, "ts": coerce_iso(metadata.get("created_at", ""))},
         checkpoint_id=checkpoint_id,
         parent_checkpoint_id=parent_checkpoint_id,
-        created_at=str(metadata.get("created_at", "")),
+        created_at=coerce_iso(metadata.get("created_at", "")),
         tasks=tasks,
     )
 
@@ -501,7 +503,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
         channel_values.update(body.values)
 
     checkpoint["channel_values"] = channel_values
-    metadata["updated_at"] = time.time()
+    metadata["updated_at"] = now_iso()
 
     if body.as_node:
         metadata["source"] = "update"
@@ -542,7 +544,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
         next=[],
         metadata=metadata,
         checkpoint_id=new_checkpoint_id,
-        created_at=str(metadata.get("created_at", "")),
+        created_at=coerce_iso(metadata.get("created_at", "")),
     )
 
 
@@ -609,7 +611,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
                 parent_checkpoint_id=parent_id,
                 metadata=user_meta,
                 values=values,
-                created_at=str(metadata.get("created_at", "")),
+                created_at=coerce_iso(metadata.get("created_at", "")),
                 next=next_tasks,
             )
         )
 
@@ -5,7 +5,7 @@ import os
 import stat
 
 from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from app.gateway.authz import require_permission
 from app.gateway.deps import get_config
@@ -15,12 +15,14 @@ from deerflow.runtime.user_context import get_effective_user_id
 from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
 from deerflow.uploads.manager import (
     PathTraversalError,
+    UnsafeUploadPathError,
     delete_file_safe,
     enrich_file_listing,
     ensure_uploads_dir,
     get_uploads_dir,
     list_files_in_dir,
     normalize_filename,
+    open_upload_file_no_symlink,
     upload_artifact_url,
     upload_virtual_path,
 )
@@ -30,6 +32,11 @@ logger = logging.getLogger(__name__)
 
 router = APIRouter(prefix="/api/threads/{thread_id}/uploads", tags=["uploads"])
 
+UPLOAD_CHUNK_SIZE = 8192
+DEFAULT_MAX_FILES = 10
+DEFAULT_MAX_FILE_SIZE = 50 * 1024 * 1024
+DEFAULT_MAX_TOTAL_SIZE = 100 * 1024 * 1024
+
 
 class UploadResponse(BaseModel):
     """Response model for file upload."""
@@ -37,6 +44,15 @@ class UploadResponse(BaseModel):
     success: bool
     files: list[dict[str, str]]
     message: str
+    skipped_files: list[str] = Field(default_factory=list)
+
+
+class UploadLimits(BaseModel):
+    """Application-level upload limits exposed to clients."""
+
+    max_files: int
+    max_file_size: int
+    max_total_size: int
 
 
 def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
@@ -69,6 +85,72 @@ def _get_uploads_config_value(app_config: AppConfig, key: str, default: object)
     return getattr(uploads_cfg, key, default)
 
 
+def _get_upload_limit(app_config: AppConfig, key: str, default: int, *, legacy_key: str | None = None) -> int:
+    try:
+        value = _get_uploads_config_value(app_config, key, None)
+        if value is None and legacy_key is not None:
+            value = _get_uploads_config_value(app_config, legacy_key, None)
+        if value is None:
+            value = default
+        limit = int(value)
+        if limit <= 0:
+            raise ValueError
+        return limit
+    except Exception:
+        logger.warning("Invalid uploads.%s value; falling back to %d", key, default)
+        return default
+
+
+def _get_upload_limits(app_config: AppConfig) -> UploadLimits:
+    return UploadLimits(
+        max_files=_get_upload_limit(app_config, "max_files", DEFAULT_MAX_FILES, legacy_key="max_file_count"),
+        max_file_size=_get_upload_limit(app_config, "max_file_size", DEFAULT_MAX_FILE_SIZE, legacy_key="max_single_file_size"),
+        max_total_size=_get_upload_limit(app_config, "max_total_size", DEFAULT_MAX_TOTAL_SIZE),
+    )
+
+
+def _cleanup_uploaded_paths(paths: list[os.PathLike[str] | str]) -> None:
+    for path in reversed(paths):
+        try:
+            os.unlink(path)
+        except FileNotFoundError:
+            pass
+        except Exception:
+            logger.warning("Failed to clean up upload path after rejected request: %s", path, exc_info=True)
+
+
+async def _write_upload_file_with_limits(
+    file: UploadFile,
+    *,
+    uploads_dir: os.PathLike[str] | str,
+    display_filename: str,
+    max_single_file_size: int,
+    max_total_size: int,
+    total_size: int,
+) -> tuple[os.PathLike[str] | str, int, int]:
+    file_size = 0
+    file_path, fh = open_upload_file_no_symlink(uploads_dir, display_filename)
+    try:
+        while chunk := await file.read(UPLOAD_CHUNK_SIZE):
+            file_size += len(chunk)
+            total_size += len(chunk)
+            if file_size > max_single_file_size:
+                raise HTTPException(status_code=413, detail=f"File too large: {display_filename}")
+            if total_size > max_total_size:
+                raise HTTPException(status_code=413, detail="Total upload size too large")
+            fh.write(chunk)
+    except Exception:
+        fh.close()
+        try:
+            os.unlink(file_path)
+        except FileNotFoundError:
+            pass
+        raise
+    else:
+        fh.close()
+        return file_path, file_size, total_size
+
+
 def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
     """Return whether automatic host-side document conversion is enabled.
 
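To illustrate the fallback chain above, suppose an `app_config` whose `uploads` section sets only the legacy key (a hypothetical configuration, not a repository test):

```python
# Suppose config.yaml contains only the legacy key:
#   uploads:
#     max_file_count: 20
limits = _get_upload_limits(app_config)
assert limits.max_files == 20                           # legacy key honored
assert limits.max_file_size == DEFAULT_MAX_FILE_SIZE    # 50 MiB default
assert limits.max_total_size == DEFAULT_MAX_TOTAL_SIZE  # 100 MiB default
```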
@@ -96,12 +178,20 @@ async def upload_files(
     if not files:
         raise HTTPException(status_code=400, detail="No files provided")
 
+    limits = _get_upload_limits(config)
+    if len(files) > limits.max_files:
+        raise HTTPException(status_code=413, detail=f"Too many files: maximum is {limits.max_files}")
+
     try:
         uploads_dir = ensure_uploads_dir(thread_id)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
     sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
     uploaded_files = []
+    written_paths = []
+    sandbox_sync_targets = []
+    skipped_files = []
+    total_size = 0
 
     sandbox_provider = get_sandbox_provider()
     sync_to_sandbox = not _uses_thread_data_mounts(sandbox_provider)
@@ -109,6 +199,8 @@ async def upload_files(
     if sync_to_sandbox:
         sandbox_id = sandbox_provider.acquire(thread_id)
         sandbox = sandbox_provider.get(sandbox_id)
         if sandbox is None:
             raise HTTPException(status_code=500, detail="Failed to acquire sandbox")
+    auto_convert_documents = _auto_convert_documents_enabled(config)
+
     for file in files:
@@ -122,35 +214,40 @@ async def upload_files(
             continue
 
         try:
-            content = await file.read()
-            file_path = uploads_dir / safe_filename
-            file_path.write_bytes(content)
+            file_path, file_size, total_size = await _write_upload_file_with_limits(
+                file,
+                uploads_dir=uploads_dir,
+                display_filename=safe_filename,
+                max_single_file_size=limits.max_file_size,
+                max_total_size=limits.max_total_size,
+                total_size=total_size,
+            )
+            written_paths.append(file_path)
 
             virtual_path = upload_virtual_path(safe_filename)
 
-            if sync_to_sandbox and sandbox is not None:
-                _make_file_sandbox_writable(file_path)
-                sandbox.update_file(virtual_path, content)
+            if sync_to_sandbox:
+                sandbox_sync_targets.append((file_path, virtual_path))
 
             file_info = {
                 "filename": safe_filename,
-                "size": str(len(content)),
+                "size": str(file_size),
                 "path": str(sandbox_uploads / safe_filename),
                 "virtual_path": virtual_path,
                 "artifact_url": upload_artifact_url(thread_id, safe_filename),
             }
 
-            logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
+            logger.info(f"Saved file: {safe_filename} ({file_size} bytes) to {file_info['path']}")
 
             file_ext = file_path.suffix.lower()
             if auto_convert_documents and file_ext in CONVERTIBLE_EXTENSIONS:
                 md_path = await convert_file_to_markdown(file_path)
                 if md_path:
+                    written_paths.append(md_path)
                     md_virtual_path = upload_virtual_path(md_path.name)
 
-                    if sync_to_sandbox and sandbox is not None:
-                        _make_file_sandbox_writable(md_path)
-                        sandbox.update_file(md_virtual_path, md_path.read_bytes())
+                    if sync_to_sandbox:
+                        sandbox_sync_targets.append((md_path, md_virtual_path))
 
                     file_info["markdown_file"] = md_path.name
                     file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
@@ -159,17 +256,46 @@ async def upload_files(
 
             uploaded_files.append(file_info)
 
+        except HTTPException as e:
+            _cleanup_uploaded_paths(written_paths)
+            raise e
+        except UnsafeUploadPathError as e:
+            logger.warning("Skipping upload with unsafe destination %s: %s", file.filename, e)
+            skipped_files.append(safe_filename)
+            continue
         except Exception as e:
             logger.error(f"Failed to upload {file.filename}: {e}")
+            _cleanup_uploaded_paths(written_paths)
             raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
 
+    if sync_to_sandbox:
+        for file_path, virtual_path in sandbox_sync_targets:
+            _make_file_sandbox_writable(file_path)
+            sandbox.update_file(virtual_path, file_path.read_bytes())
+
+    message = f"Successfully uploaded {len(uploaded_files)} file(s)"
+    if skipped_files:
+        message += f"; skipped {len(skipped_files)} unsafe file(s)"
+
     return UploadResponse(
-        success=True,
+        success=not skipped_files,
         files=uploaded_files,
-        message=f"Successfully uploaded {len(uploaded_files)} file(s)",
+        message=message,
+        skipped_files=skipped_files,
     )
 
 
+@router.get("/limits", response_model=UploadLimits)
+@require_permission("threads", "read", owner_check=True)
+async def get_upload_limits(
+    thread_id: str,
+    request: Request,
+    config: AppConfig = Depends(get_config),
+) -> UploadLimits:
+    """Return upload limits used by the gateway for this thread."""
+    return _get_upload_limits(config)
+
+
 @router.get("/list", response_model=dict)
 @require_permission("threads", "read", owner_check=True)
 async def list_uploaded_files(thread_id: str, request: Request) -> dict:
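A sketch of how a frontend or script might consume the new `/limits` endpoint (base URL and thread id are placeholders, and authentication is omitted):

```python
import httpx

BASE = "http://localhost:8000"  # placeholder gateway address
thread_id = "demo-thread"       # placeholder thread id

resp = httpx.get(f"{BASE}/api/threads/{thread_id}/uploads/limits")
resp.raise_for_status()
limits = resp.json()
# e.g. {"max_files": 10, "max_file_size": 52428800, "max_total_size": 104857600}
print(limits["max_files"])
```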
@@ -321,12 +321,16 @@ models:
 - `DEEPSEEK_API_KEY` - DeepSeek API key
 - `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
 - `TAVILY_API_KEY` - Tavily search API key
+- `DEER_FLOW_PROJECT_ROOT` - Project root for relative runtime paths
 - `DEER_FLOW_CONFIG_PATH` - Custom config file path
 - `DEER_FLOW_EXTENSIONS_CONFIG_PATH` - Custom extensions config file path
+- `DEER_FLOW_HOME` - Runtime state directory (defaults to `.deer-flow` under the project root)
+- `DEER_FLOW_SKILLS_PATH` - Skills directory when `skills.path` is omitted
 - `GATEWAY_ENABLE_DOCS` - Set to `false` to disable Swagger UI (`/docs`), ReDoc (`/redoc`), and OpenAPI schema (`/openapi.json`) endpoints (default: `true`)
 
 ## Configuration Location
 
-The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`), not in the backend directory.
+The configuration file should be placed in the **project root directory** (`deer-flow/config.yaml`). Set `DEER_FLOW_PROJECT_ROOT` when the process may start from another working directory, or set `DEER_FLOW_CONFIG_PATH` to point at a specific file.
 
 ## Configuration Priority
 
@@ -334,12 +338,12 @@ DeerFlow searches for configuration in this order:
 
 1. Path specified in code via `config_path` argument
 2. Path from `DEER_FLOW_CONFIG_PATH` environment variable
-3. `config.yaml` in current working directory (typically `backend/` when running)
-4. `config.yaml` in parent directory (project root: `deer-flow/`)
+3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or under the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset
+4. Legacy backend/repository-root locations for monorepo compatibility
 
 ## Best Practices
 
-1. **Place `config.yaml` in project root** - Not in `backend/` directory
+1. **Place `config.yaml` in project root** - Set `DEER_FLOW_PROJECT_ROOT` if the runtime starts elsewhere
 2. **Never commit `config.yaml`** - It's already in `.gitignore`
 3. **Use environment variables for secrets** - Don't hardcode API keys
 4. **Keep `config.example.yaml` updated** - Document all new options
@@ -350,7 +354,7 @@ DeerFlow searches for configuration in this order:
 
 ### "Config file not found"
 - Ensure `config.yaml` exists in the **project root** directory (`deer-flow/config.yaml`)
-- The backend searches parent directory by default, so root location is preferred
+- If the runtime starts outside the project root, set `DEER_FLOW_PROJECT_ROOT`
 - Alternatively, set `DEER_FLOW_CONFIG_PATH` environment variable to custom location
 
 ### "Invalid API key"
@@ -360,7 +364,7 @@ DeerFlow searches for configuration in this order:
 ### "Skills not loading"
 - Check that `deer-flow/skills/` directory exists
 - Verify skills have valid `SKILL.md` files
-- Check `skills.path` configuration if using custom path
+- Check `skills.path` or `DEER_FLOW_SKILLS_PATH` if using a custom path
 
 ### "Docker sandbox fails to start"
 - Ensure Docker is running
 
@@ -22,6 +22,8 @@ POST /api/threads/{thread_id}/uploads
 **Request body:** `multipart/form-data`
 - `files`: one or more files
 
+The gateway enforces upload limits at the application layer: by default at most 10 files, 50 MiB per file, and 100 MiB total per request. These limits can be tuned via `uploads.max_files`, `uploads.max_file_size`, and `uploads.max_total_size` in `config.yaml`; the frontend reads the same limits to warn at file-selection time, and the backend returns `413 Payload Too Large` when a limit is exceeded.
+
 **Response:**
 ```json
 {
@@ -48,7 +50,23 @@ POST /api/threads/{thread_id}/uploads
 - `virtual_path`: the virtual path the agent uses inside the sandbox
 - `artifact_url`: the URL the frontend uses to fetch the file over HTTP
 
-### 2. List uploaded files
+### 2. Query upload limits
+```
+GET /api/threads/{thread_id}/uploads/limits
+```
+
+Returns the upload limits currently in effect at the gateway, so the frontend can warn and block before the user picks files.
+
+**Response:**
+```json
+{
+  "max_files": 10,
+  "max_file_size": 52428800,
+  "max_total_size": 104857600
+}
+```
+
+### 3. List uploaded files
 ```
 GET /api/threads/{thread_id}/uploads/list
 ```
@@ -71,7 +89,7 @@ GET /api/threads/{thread_id}/uploads/list
 }
 ```
 
-### 3. Delete a file
+### 4. Delete a file
 ```
 DELETE /api/threads/{thread_id}/uploads/{filename}
 ```
 
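For completeness, a hypothetical client-side upload that reacts to the `413 Payload Too Large` behavior described above (all values are placeholders; authentication is omitted):

```python
import httpx

BASE = "http://localhost:8000"  # placeholder gateway address
thread_id = "demo-thread"       # placeholder thread id

files = [("files", ("notes.txt", b"hello world", "text/plain"))]
resp = httpx.post(f"{BASE}/api/threads/{thread_id}/uploads", files=files)
if resp.status_code == 413:
    print("Rejected: request exceeds the configured file count or size limits")
else:
    resp.raise_for_status()
    print(resp.json()["files"])
```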
@@ -23,6 +23,9 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r
 # Option A: Set environment variables (recommended)
 export OPENAI_API_KEY="your-key-here"
 
+# Optional: pin the project root when running from another directory
+export DEER_FLOW_PROJECT_ROOT="/path/to/deer-flow"
+
 # Option B: Edit config.yaml directly
 vim config.yaml  # or your preferred editor
 ```
@@ -35,17 +38,20 @@ DeerFlow uses a YAML configuration file that should be placed in the **project r
 
 ## Important Notes
 
-- **Location**: `config.yaml` should be in `deer-flow/` (project root), not `deer-flow/backend/`
+- **Location**: `config.yaml` should be in `deer-flow/` (project root)
 - **Git**: `config.yaml` is automatically ignored by git (contains secrets)
-- **Priority**: If both `backend/config.yaml` and `../config.yaml` exist, backend version takes precedence
+- **Runtime root**: Set `DEER_FLOW_PROJECT_ROOT` if DeerFlow may start from outside the project root
+- **Runtime data**: State defaults to `.deer-flow` under the project root; set `DEER_FLOW_HOME` to move it
+- **Skills**: Skills default to `skills/` under the project root; set `DEER_FLOW_SKILLS_PATH` or `skills.path` to move them
 
 ## Configuration File Locations
 
 The backend searches for `config.yaml` in this order:
 
-1. `DEER_FLOW_CONFIG_PATH` environment variable (if set)
-2. `backend/config.yaml` (current directory when running from backend/)
-3. `deer-flow/config.yaml` (parent directory - **recommended location**)
+1. Explicit `config_path` argument from code
+2. `DEER_FLOW_CONFIG_PATH` environment variable (if set)
+3. `config.yaml` under `DEER_FLOW_PROJECT_ROOT`, or the current working directory when `DEER_FLOW_PROJECT_ROOT` is unset
+4. Legacy backend/repository-root locations for monorepo compatibility
 
 **Recommended**: Place `config.yaml` in project root (`deer-flow/config.yaml`).
 
@@ -77,8 +83,8 @@ python -c "from deerflow.config.app_config import AppConfig; print(AppConfig.res
 
 If it can't find the config:
 1. Ensure you've copied `config.example.yaml` to `config.yaml`
-2. Verify you're in the correct directory
-3. Check the file exists: `ls -la ../config.yaml`
+2. Verify you're in the project root, or set `DEER_FLOW_PROJECT_ROOT`
+3. Check the file exists: `ls -la config.yaml`
 
 ### Permission denied
 
@@ -89,4 +95,4 @@ chmod 600 ../config.yaml  # Protect sensitive configuration
 ## See Also
 
 - [Configuration Guide](CONFIGURATION.md) - Detailed configuration options
-- [Architecture Overview](../CLAUDE.md) - System architecture
+- [Architecture Overview](../CLAUDE.md) - System architecture
@@ -19,8 +19,6 @@ from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddlewar
 from deerflow.agents.thread_state import ThreadState
 from deerflow.config.agents_config import load_agent_config, validate_agent_name
 from deerflow.config.app_config import AppConfig, get_app_config
-from deerflow.config.memory_config import get_memory_config
-from deerflow.config.summarization_config import get_summarization_config
 from deerflow.models import create_chat_model
 
 logger = logging.getLogger(__name__)
@@ -52,7 +50,8 @@ def _resolve_model_name(requested_model_name: str | None = None, *, app_config:
 
 def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
     """Create and configure the summarization middleware from config."""
-    config = get_summarization_config()
+    resolved_app_config = app_config or get_app_config()
+    config = resolved_app_config.summarization
 
     if not config.enabled:
         return None
@@ -73,9 +72,9 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
     # as middleware rather than lead_agent (SummarizationMiddleware is a
     # LangChain built-in, so we tag the model at creation time).
     if config.model_name:
-        model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
+        model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=resolved_app_config)
     else:
-        model = create_chat_model(thinking_enabled=False, app_config=app_config)
+        model = create_chat_model(thinking_enabled=False, app_config=resolved_app_config)
     model = model.with_config(tags=["middleware:summarize"])
 
     # Prepare kwargs
@@ -92,18 +91,13 @@ def _create_summarization_middleware(*, app_config: AppConfig | None = None) ->
     kwargs["summary_prompt"] = config.summary_prompt
 
     hooks: list[BeforeSummarizationHook] = []
-    if get_memory_config().enabled:
+    if resolved_app_config.memory.enabled:
         hooks.append(memory_flush_hook)
 
     # The logic below relies on two assumptions holding true: this factory is
     # the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
     # config is not expected to change after startup.
-    try:
-        resolved_app_config = app_config or get_app_config()
-        skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
-    except Exception:
-        logger.exception("Failed to resolve skills container path; falling back to default")
-        skills_container_path = "/mnt/skills"
+    skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
 
     return DeerFlowSummarizationMiddleware(
         **kwargs,
@@ -279,10 +273,10 @@ def _build_middlewares(
     middlewares.append(TokenUsageMiddleware())
 
     # Add TitleMiddleware
-    middlewares.append(TitleMiddleware())
+    middlewares.append(TitleMiddleware(app_config=resolved_app_config))
 
     # Add MemoryMiddleware (after TitleMiddleware)
-    middlewares.append(MemoryMiddleware(agent_name=agent_name))
+    middlewares.append(MemoryMiddleware(agent_name=agent_name, memory_config=resolved_app_config.memory))
 
     # Add ViewImageMiddleware only if the current model supports vision.
     # Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
@@ -316,7 +310,9 @@ def _build_middlewares(
 
 def make_lead_agent(config: RunnableConfig):
     """LangGraph graph factory; keep the signature compatible with LangGraph Server."""
-    return _make_lead_agent(config, app_config=get_app_config())
+    runtime_config = _get_runtime_config(config)
+    runtime_app_config = runtime_config.get("app_config")
+    return _make_lead_agent(config, app_config=runtime_app_config or get_app_config())
 
 
 def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
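The exact key layout `_get_runtime_config` reads is not shown in this diff; assuming it surfaces values from the runnable config, a caller could thread a request-scoped config through like this (illustrative only, not repository code):

```python
# Hypothetical wiring: supply a request-scoped AppConfig so the factory no
# longer falls back to the global get_app_config() singleton.
my_app_config = AppConfig.from_file("config.yaml")
runnable_config = {"configurable": {"app_config": my_app_config}}
agent = make_lead_agent(runnable_config)
```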
@@ -158,7 +158,7 @@ Skip simple one-off tasks.
 """
 
 
-def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
+def _build_available_subagents_description(available_names: list[str], bash_available: bool, *, app_config: AppConfig | None = None) -> str:
     """Dynamically build subagent type descriptions from registry.
 
     Mirrors Codex's pattern where agent_type_description is dynamically generated
@@ -180,7 +180,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
         if name in builtin_descriptions:
             lines.append(f"- **{name}**: {builtin_descriptions[name]}")
         else:
-            config = get_subagent_config(name)
+            config = get_subagent_config(name, app_config=app_config)
             if config is not None:
                 desc = config.description.split("\n")[0].strip()  # First line only for brevity
                 lines.append(f"- **{name}**: {desc}")
@@ -188,7 +188,7 @@ def _build_available_subagents_description(available_names: list[str], bash_avai
     return "\n".join(lines)
 
 
-def _build_subagent_section(max_concurrent: int) -> str:
+def _build_subagent_section(max_concurrent: int, *, app_config: AppConfig | None = None) -> str:
     """Build the subagent system prompt section with dynamic concurrency limit.
 
     Args:
@@ -198,12 +198,12 @@ def _build_subagent_section(max_concurrent: int) -> str:
         Formatted subagent section string.
     """
     n = max_concurrent
-    available_names = get_available_subagent_names()
+    available_names = get_available_subagent_names(app_config=app_config) if app_config is not None else get_available_subagent_names()
     bash_available = "bash" in available_names
 
     # Dynamically build subagent type descriptions from registry (aligned with Codex's
     # agent_type_description pattern where all registered roles are listed in the tool spec).
-    available_subagents = _build_available_subagents_description(available_names, bash_available)
+    available_subagents = _build_available_subagents_description(available_names, bash_available, app_config=app_config)
     direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
     direct_execution_example = (
         '# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test")  # Direct execution, not task()'
@@ -530,21 +530,28 @@ combined with a FastAPI gateway for REST API access [citation:FastAPI](https://f
 """
 
 
-def _get_memory_context(agent_name: str | None = None) -> str:
+def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
     """Get memory context for injection into system prompt.
 
     Args:
         agent_name: If provided, loads per-agent memory. If None, loads global memory.
+        app_config: Explicit application config. When provided, memory options
+            are read from this value instead of the global config singleton.
 
     Returns:
         Formatted memory context string wrapped in XML tags, or empty string if disabled.
     """
     try:
         from deerflow.agents.memory import format_memory_for_injection, get_memory_data
-        from deerflow.config.memory_config import get_memory_config
         from deerflow.runtime.user_context import get_effective_user_id
 
-        config = get_memory_config()
+        if app_config is None:
+            from deerflow.config.memory_config import get_memory_config
+
+            config = get_memory_config()
+        else:
+            config = app_config.memory
 
         if not config.enabled or not config.injection_enabled:
             return ""
 
@@ -558,8 +565,8 @@ def _get_memory_context(agent_name: str | None = None) -> str:
 {memory_content}
 </memory>
 """
-    except Exception as e:
-        logger.error("Failed to load memory context: %s", e)
+    except Exception:
+        logger.exception("Failed to load memory context")
         return ""
 
 
@@ -599,15 +606,20 @@ def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_c
     """Generate the skills prompt section with available skills list."""
     skills = _get_enabled_skills_for_config(app_config)
 
-    try:
-        from deerflow.config import get_app_config
-
-        config = app_config or get_app_config()
-        container_base_path = config.skills.container_path
-        skill_evolution_enabled = config.skill_evolution.enabled
-    except Exception:
-        container_base_path = "/mnt/skills"
-        skill_evolution_enabled = False
+    if app_config is None:
+        try:
+            from deerflow.config import get_app_config
+
+            config = get_app_config()
+            container_base_path = config.skills.container_path
+            skill_evolution_enabled = config.skill_evolution.enabled
+        except Exception:
+            container_base_path = "/mnt/skills"
+            skill_evolution_enabled = False
+    else:
+        config = app_config
+        container_base_path = config.skills.container_path
+        skill_evolution_enabled = config.skill_evolution.enabled
 
     if not skills and not skill_evolution_enabled:
         return ""
@@ -640,13 +652,17 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) ->
     """
     from deerflow.tools.builtins.tool_search import get_deferred_registry
 
-    try:
-        from deerflow.config import get_app_config
-
-        config = app_config or get_app_config()
-        if not config.tool_search.enabled:
-            return ""
-    except Exception:
-        return ""
+    if app_config is None:
+        try:
+            from deerflow.config import get_app_config
+
+            config = get_app_config()
+        except Exception:
+            return ""
+    else:
+        config = app_config
+
+    if not config.tool_search.enabled:
+        return ""
 
     registry = get_deferred_registry()
@@ -657,15 +673,19 @@ def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) ->
     return f"<available-deferred-tools>\n{names}\n</available-deferred-tools>"
 
 
-def _build_acp_section() -> str:
+def _build_acp_section(*, app_config: AppConfig | None = None) -> str:
     """Build the ACP agent prompt section, only if ACP agents are configured."""
-    try:
-        from deerflow.config.acp_config import get_acp_agents
-
-        agents = get_acp_agents()
-        if not agents:
-            return ""
-    except Exception:
-        return ""
+    if app_config is None:
+        try:
+            from deerflow.config.acp_config import get_acp_agents
+
+            agents = get_acp_agents()
+        except Exception:
+            return ""
+    else:
+        agents = getattr(app_config, "acp_agents", {}) or {}
+
+    if not agents:
+        return ""
 
     return (
@@ -679,14 +699,18 @@ def _build_acp_section() -> str:
 
 def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
     """Build a prompt section for explicitly configured sandbox mounts."""
-    try:
-        from deerflow.config import get_app_config
-
-        config = app_config or get_app_config()
-        mounts = config.sandbox.mounts or []
-    except Exception:
-        logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
-        return ""
+    if app_config is None:
+        try:
+            from deerflow.config import get_app_config
+
+            config = get_app_config()
+        except Exception:
+            logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
+            return ""
+    else:
+        config = app_config
+
+    mounts = config.sandbox.mounts or []
 
     if not mounts:
         return ""
@@ -709,11 +733,11 @@ def apply_prompt_template(
     app_config: AppConfig | None = None,
 ) -> str:
     # Get memory context
-    memory_context = _get_memory_context(agent_name)
+    memory_context = _get_memory_context(agent_name, app_config=app_config)
 
     # Include subagent section only if enabled (from runtime parameter)
     n = max_concurrent_subagents
-    subagent_section = _build_subagent_section(n) if subagent_enabled else ""
+    subagent_section = _build_subagent_section(n, app_config=app_config) if subagent_enabled else ""
 
     # Add subagent reminder to critical_reminders if enabled
     subagent_reminder = (
@@ -740,7 +764,7 @@ def apply_prompt_template(
     deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
 
     # Build ACP agent section only if ACP agents are configured
-    acp_section = _build_acp_section()
+    acp_section = _build_acp_section(app_config=app_config)
     custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
     acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
 
@@ -1,7 +1,7 @@
 """Middleware for memory mechanism."""
 
 import logging
-from typing import override
+from typing import TYPE_CHECKING, override
 
 from langchain.agents import AgentState
 from langchain.agents.middleware import AgentMiddleware
@@ -13,6 +13,9 @@ from deerflow.agents.memory.queue import get_memory_queue
 from deerflow.config.memory_config import get_memory_config
 from deerflow.runtime.user_context import get_effective_user_id
 
+if TYPE_CHECKING:
+    from deerflow.config.memory_config import MemoryConfig
+
 logger = logging.getLogger(__name__)
 
 
@@ -34,14 +37,17 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
 
     state_schema = MemoryMiddlewareState
 
-    def __init__(self, agent_name: str | None = None):
+    def __init__(self, agent_name: str | None = None, *, memory_config: "MemoryConfig | None" = None):
         """Initialize the MemoryMiddleware.
 
         Args:
             agent_name: If provided, memory is stored per-agent. If None, uses global memory.
+            memory_config: Explicit memory config. When omitted, legacy global
+                config fallback is used.
         """
         super().__init__()
         self._agent_name = agent_name
+        self._memory_config = memory_config
 
     @override
     def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
@@ -54,7 +60,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
         Returns:
             None (no state changes needed from this middleware).
         """
-        config = get_memory_config()
+        config = self._memory_config or get_memory_config()
         if not config.enabled:
             return None
 
@@ -2,7 +2,7 @@
 
 import logging
 import re
-from typing import Any, NotRequired, override
+from typing import TYPE_CHECKING, Any, NotRequired, override
 
 from langchain.agents import AgentState
 from langchain.agents.middleware import AgentMiddleware
@@ -12,6 +12,10 @@ from langgraph.runtime import Runtime
 from deerflow.config.title_config import get_title_config
 from deerflow.models import create_chat_model
 
+if TYPE_CHECKING:
+    from deerflow.config.app_config import AppConfig
+    from deerflow.config.title_config import TitleConfig
+
 logger = logging.getLogger(__name__)
 
 
@@ -26,6 +30,18 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
 
     state_schema = TitleMiddlewareState
 
+    def __init__(self, *, app_config: "AppConfig | None" = None, title_config: "TitleConfig | None" = None):
+        super().__init__()
+        self._app_config = app_config
+        self._title_config = title_config
+
+    def _get_title_config(self):
+        if self._title_config is not None:
+            return self._title_config
+        if self._app_config is not None:
+            return self._app_config.title
+        return get_title_config()
+
     def _normalize_content(self, content: object) -> str:
         if isinstance(content, str):
             return content
@@ -47,7 +63,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
 
     def _should_generate_title(self, state: TitleMiddlewareState) -> bool:
         """Check if we should generate a title for this thread."""
-        config = get_title_config()
+        config = self._get_title_config()
         if not config.enabled:
             return False
 
@@ -72,7 +88,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
 
         Returns (prompt_string, user_msg) so callers can use user_msg as fallback.
         """
-        config = get_title_config()
+        config = self._get_title_config()
         messages = state.get("messages", [])
 
         user_msg_content = next((m.content for m in messages if m.type == "human"), "")
@@ -94,14 +110,14 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
 
     def _parse_title(self, content: object) -> str:
         """Normalize model output into a clean title string."""
-        config = get_title_config()
+        config = self._get_title_config()
         title_content = self._normalize_content(content)
         title_content = self._strip_think_tags(title_content)
         title = title_content.strip().strip('"').strip("'")
         return title[: config.max_chars] if len(title) > config.max_chars else title
 
     def _fallback_title(self, user_msg: str) -> str:
-        config = get_title_config()
+        config = self._get_title_config()
         fallback_chars = min(config.max_chars, 50)
         if len(user_msg) > fallback_chars:
             return user_msg[:fallback_chars].rstrip() + "..."
@@ -135,14 +151,17 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
         if not self._should_generate_title(state):
             return None
 
-        config = get_title_config()
+        config = self._get_title_config()
        prompt, user_msg = self._build_title_prompt(state)
 
         try:
+            model_kwargs = {"thinking_enabled": False}
+            if self._app_config is not None:
+                model_kwargs["app_config"] = self._app_config
             if config.model_name:
-                model = create_chat_model(name=config.model_name, thinking_enabled=False)
+                model = create_chat_model(name=config.model_name, **model_kwargs)
             else:
-                model = create_chat_model(thinking_enabled=False)
+                model = create_chat_model(**model_kwargs)
             response = await model.ainvoke(prompt, config=self._get_runnable_config())
             title = self._parse_title(response.content)
             if title:
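Taken together, these changes let a caller wire one explicit `AppConfig` through the middleware chain instead of relying on module-level singletons. Roughly, using only constructors visible in this diff (the agent name is a placeholder):

```python
app_config = AppConfig.from_file("config.yaml")

middlewares = [
    # Title model and limits now come from app_config.title, not the global config.
    TitleMiddleware(app_config=app_config),
    # Memory options likewise come from the injected config.
    MemoryMiddleware(agent_name="lead", memory_config=app_config.memory),
]
```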
@@ -136,11 +136,32 @@ def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = T
     )
 
 
-def build_subagent_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
+def build_subagent_runtime_middlewares(
+    *,
+    app_config: AppConfig | None = None,
+    model_name: str | None = None,
+    lazy_init: bool = True,
+) -> list[AgentMiddleware]:
     """Middlewares shared by subagent runtime before subagent-only middlewares."""
-    return _build_runtime_middlewares(
+    if app_config is None:
+        from deerflow.config import get_app_config
+
+        app_config = get_app_config()
+
+    middlewares = _build_runtime_middlewares(
         app_config=app_config,
         include_uploads=False,
         include_dangling_tool_call_patch=True,
         lazy_init=lazy_init,
     )
+
+    if model_name is None and app_config.models:
+        model_name = app_config.models[0].name
+
+    model_config = app_config.get_model_config(model_name) if model_name else None
+    if model_config is not None and model_config.supports_vision:
+        from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
+
+        middlewares.append(ViewImageMiddleware())
+
+    return middlewares
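Usage sketch: with the new keywords, a subagent stack gains `ViewImageMiddleware` only when the selected model reports vision support (the model name below is a placeholder, not a repository default):

```python
# Falls back to the first configured model when model_name is omitted.
mws = build_subagent_runtime_middlewares(app_config=app_config, model_name="gpt-4o")
```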
@@ -48,6 +48,12 @@ class AioSandbox(Sandbox):
         self._home_dir = context.home_dir
         return self._home_dir
 
+    # Default no_change_timeout for exec_command (seconds). Matches the
+    # client-level timeout so that long-running commands which produce no
+    # output are not prematurely terminated by the sandbox's built-in 120 s
+    # default.
+    _DEFAULT_NO_CHANGE_TIMEOUT = 600
+
     def execute_command(self, command: str) -> str:
         """Execute a shell command in the sandbox.
 
@@ -66,13 +72,13 @@ class AioSandbox(Sandbox):
         """
         with self._lock:
             try:
-                result = self._client.shell.exec_command(command=command)
+                result = self._client.shell.exec_command(command=command, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
                 output = result.data.output if result.data else ""
 
                 if output and _ERROR_OBSERVATION_SIGNATURE in output:
                     logger.warning("ErrorObservation detected in sandbox output, retrying with a fresh session")
                     fresh_id = str(uuid.uuid4())
-                    result = self._client.shell.exec_command(command=command, id=fresh_id)
+                    result = self._client.shell.exec_command(command=command, id=fresh_id, no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
                    output = result.data.output if result.data else ""
 
                 return output if output else "(no output)"
@@ -108,7 +114,7 @@ class AioSandbox(Sandbox):
         """
         with self._lock:
             try:
-                result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500")
+                result = self._client.shell.exec_command(command=f"find {shlex.quote(path)} -maxdepth {max_depth} -type f -o -type d 2>/dev/null | head -500", no_change_timeout=self._DEFAULT_NO_CHANGE_TIMEOUT)
                 output = result.data.output if result.data else ""
                 if output:
                     return [line.strip() for line in output.strip().split("\n") if line.strip()]
 
@@ -0,0 +1,3 @@
+from .tools import web_search_tool
+
+__all__ = ["web_search_tool"]

backend/packages/harness/deerflow/community/serper/tools.py (new file, 95 lines)
@@ -0,0 +1,95 @@
+"""
+Web Search Tool - Search the web using Serper (Google Search API).
+
+Serper provides real-time Google Search results via a JSON API.
+An API key is required. Sign up at https://serper.dev to get one.
+"""
+
+import json
+import logging
+import os
+
+import httpx
+from langchain.tools import tool
+
+from deerflow.config import get_app_config
+
+logger = logging.getLogger(__name__)
+
+_SERPER_ENDPOINT = "https://google.serper.dev/search"
+_api_key_warned = False
+
+
+def _get_api_key() -> str | None:
+    config = get_app_config().get_tool_config("web_search")
+    if config is not None:
+        api_key = config.model_extra.get("api_key")
+        if isinstance(api_key, str) and api_key.strip():
+            return api_key
+    return os.getenv("SERPER_API_KEY")
+
+
+@tool("web_search", parse_docstring=True)
+def web_search_tool(query: str, max_results: int = 5) -> str:
+    """Search the web for information using Google Search via Serper.
+
+    Args:
+        query: Search keywords describing what you want to find. Be specific for better results.
+        max_results: Maximum number of search results to return. Default is 5.
+    """
+    global _api_key_warned
+
+    config = get_app_config().get_tool_config("web_search")
+    if config is not None and "max_results" in config.model_extra:
+        max_results = config.model_extra.get("max_results", max_results)
+
+    api_key = _get_api_key()
+    if not api_key:
+        if not _api_key_warned:
+            _api_key_warned = True
+            logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev")
+        return json.dumps(
+            {"error": "SERPER_API_KEY is not configured", "query": query},
+            ensure_ascii=False,
+        )
+
+    headers = {
+        "X-API-KEY": api_key,
+        "Content-Type": "application/json",
+    }
+    payload = {"q": query, "num": max_results}
+
+    try:
+        with httpx.Client(timeout=30) as client:
+            response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload)
+            response.raise_for_status()
+            data = response.json()
+    except httpx.HTTPStatusError as e:
+        logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}")
+        return json.dumps(
+            {"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
+            ensure_ascii=False,
+        )
+    except Exception as e:
+        logger.error(f"Serper search failed: {type(e).__name__}: {e}")
+        return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
+
+    organic = data.get("organic", [])
+    if not organic:
+        return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
+
+    normalized_results = [
+        {
+            "title": r.get("title", ""),
+            "url": r.get("link", ""),
+            "content": r.get("snippet", ""),
+        }
+        for r in organic[:max_results]
+    ]
+
+    output = {
+        "query": query,
+        "total_results": len(normalized_results),
+        "results": normalized_results,
+    }
+    return json.dumps(output, indent=2, ensure_ascii=False)
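A quick invocation sketch for the new tool; the import path follows the file location above, and `invoke` is the standard LangChain tool entry point (requires `SERPER_API_KEY` in the environment or an `api_key` under the `web_search` tool config):

```python
from deerflow.community.serper import web_search_tool

result_json = web_search_tool.invoke({"query": "bytedance deer-flow", "max_results": 3})
print(result_json)  # JSON string: {"query": ..., "total_results": ..., "results": [...]}
```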
@@ -8,7 +8,7 @@ import yaml
 from dotenv import load_dotenv
 from pydantic import BaseModel, ConfigDict, Field
 
-from deerflow.config.acp_config import load_acp_config_from_dict
+from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
 from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
 from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
 from deerflow.config.database_config import DatabaseConfig
@@ -17,6 +17,7 @@ from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_
 from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
 from deerflow.config.model_config import ModelConfig
 from deerflow.config.run_events_config import RunEventsConfig
+from deerflow.config.runtime_paths import existing_project_file
 from deerflow.config.sandbox_config import SandboxConfig
 from deerflow.config.skill_evolution_config import SkillEvolutionConfig
 from deerflow.config.skills_config import SkillsConfig
@@ -46,8 +47,8 @@ class CircuitBreakerConfig(BaseModel):
     recovery_timeout_sec: int = Field(default=60, description="Time in seconds before attempting to recover the circuit")
 
 
-def _default_config_candidates() -> tuple[Path, ...]:
-    """Return deterministic config.yaml locations without relying on cwd."""
+def _legacy_config_candidates() -> tuple[Path, ...]:
+    """Return source-tree config.yaml locations for monorepo compatibility."""
     backend_dir = Path(__file__).resolve().parents[4]
     repo_root = backend_dir.parent
     return (backend_dir / "config.yaml", repo_root / "config.yaml")
@@ -94,6 +95,7 @@ class AppConfig(BaseModel):
     summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
     memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
     agents_api: AgentsApiConfig = Field(default_factory=AgentsApiConfig, description="Custom-agent management API configuration")
+    acp_agents: dict[str, ACPAgentConfig] = Field(default_factory=dict, description="ACP-compatible agent configuration")
     subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
     guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
     circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
@@ -110,7 +112,8 @@ class AppConfig(BaseModel):
         Priority:
         1. If provided `config_path` argument, use it.
         2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
-        3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
+        3. Otherwise, search the caller project root.
+        4. Finally, search legacy backend/repository-root defaults for monorepo compatibility.
         """
         if config_path:
             path = Path(config_path)
@@ -123,10 +126,14 @@ class AppConfig(BaseModel):
                 raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
             return path
         else:
+            project_config = existing_project_file(("config.yaml",))
+            if project_config is not None:
+                return project_config
+
-            for path in _default_config_candidates():
+            for path in _legacy_config_candidates():
                 if path.exists():
                     return path
-            raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
+            raise FileNotFoundError("`config.yaml` file not found in the project root or legacy backend/repository root locations")
 
     @classmethod
     def from_file(cls, config_path: str | None = None) -> Self:
@ -7,6 +7,8 @@ from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field

from deerflow.config.runtime_paths import existing_project_file


class McpOAuthConfig(BaseModel):
    """OAuth configuration for an MCP server (HTTP/SSE transports)."""
@ -73,8 +75,8 @@ class ExtensionsConfig(BaseModel):
        Priority:
        1. If provided `config_path` argument, use it.
        2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
        3. Otherwise, check for `extensions_config.json` in the current directory, then in the parent directory.
        4. For backward compatibility, also check for `mcp_config.json` if `extensions_config.json` is not found.
        3. Otherwise, search the caller project root for `extensions_config.json`, then `mcp_config.json`.
        4. For backward compatibility, also search legacy backend/repository-root defaults.
        5. If not found, return None (extensions are optional).

        Args:
@ -83,8 +85,9 @@ class ExtensionsConfig(BaseModel):
        Resolution order:
        1. If provided `config_path` argument, use it.
        2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
        3. Otherwise, search backend/repository-root defaults for
        3. Otherwise, search the caller project root for
           `extensions_config.json`, then legacy `mcp_config.json`.
        4. Finally, search backend/repository-root defaults for monorepo compatibility.

        Returns:
            Path to the extensions config file if found, otherwise None.
@ -100,6 +103,10 @@ class ExtensionsConfig(BaseModel):
                raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
            return path
        else:
            project_config = existing_project_file(("extensions_config.json", "mcp_config.json"))
            if project_config is not None:
                return project_config

            backend_dir = Path(__file__).resolve().parents[4]
            repo_root = backend_dir.parent
            for path in (
@ -3,6 +3,8 @@ import re
import shutil
from pathlib import Path, PureWindowsPath

from deerflow.config.runtime_paths import runtime_home

# Virtual path prefix seen by agents inside the sandbox
VIRTUAL_PATH_PREFIX = "/mnt/user-data"

@ -11,9 +13,8 @@ _SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")


def _default_local_base_dir() -> Path:
    """Return the repo-local DeerFlow state directory without relying on cwd."""
    backend_dir = Path(__file__).resolve().parents[4]
    return backend_dir / ".deer-flow"
    """Return the caller project's writable DeerFlow state directory."""
    return runtime_home()


def _validate_thread_id(thread_id: str) -> str:
@ -81,7 +82,7 @@ class Paths:
    BaseDir resolution (in priority order):
    1. Constructor argument `base_dir`
    2. DEER_FLOW_HOME environment variable
    3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
    3. Caller project fallback: `{project_root}/.deer-flow`
    """

    def __init__(self, base_dir: str | Path | None = None) -> None:
41
backend/packages/harness/deerflow/config/runtime_paths.py
Normal file
@ -0,0 +1,41 @@
"""Runtime path resolution for standalone harness usage."""

import os
from pathlib import Path


def project_root() -> Path:
    """Return the caller project root for runtime-owned files."""
    if env_root := os.getenv("DEER_FLOW_PROJECT_ROOT"):
        root = Path(env_root).resolve()
        if not root.exists():
            raise ValueError(f"DEER_FLOW_PROJECT_ROOT is set to '{env_root}', but the resolved path '{root}' does not exist.")
        if not root.is_dir():
            raise ValueError(f"DEER_FLOW_PROJECT_ROOT is set to '{env_root}', but the resolved path '{root}' is not a directory.")
        return root
    return Path.cwd().resolve()


def runtime_home() -> Path:
    """Return the writable DeerFlow state directory."""
    if env_home := os.getenv("DEER_FLOW_HOME"):
        return Path(env_home).resolve()
    return project_root() / ".deer-flow"


def resolve_path(value: str | os.PathLike[str], *, base: Path | None = None) -> Path:
    """Resolve absolute paths as-is and relative paths against the project root."""
    path = Path(value)
    if not path.is_absolute():
        path = (base or project_root()) / path
    return path.resolve()


def existing_project_file(names: tuple[str, ...]) -> Path | None:
    """Return the first existing named file under the project root."""
    root = project_root()
    for name in names:
        candidate = root / name
        if candidate.is_file():
            return candidate
    return None
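For orientation, a minimal sketch of how these helpers compose: DEER_FLOW_PROJECT_ROOT anchors everything, while DEER_FLOW_HOME moves only the state directory. The temp dir stands in for a real project root, since project_root() raises when the env value does not exist.

import os
import tempfile
from pathlib import Path

from deerflow.config.runtime_paths import resolve_path, runtime_home

with tempfile.TemporaryDirectory() as tmp:
    os.environ["DEER_FLOW_PROJECT_ROOT"] = tmp  # must be an existing directory
    assert runtime_home() == Path(tmp).resolve() / ".deer-flow"
    assert resolve_path("skills") == Path(tmp).resolve() / "skills"

    os.environ["DEER_FLOW_HOME"] = str(Path(tmp) / "state")  # state moves, project root stays
    assert runtime_home() == (Path(tmp) / "state").resolve()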
@ -1,11 +1,9 @@
import os
from pathlib import Path

from pydantic import BaseModel, Field


def _default_repo_root() -> Path:
    """Resolve the repo root without relying on the current working directory."""
    return Path(__file__).resolve().parents[5]
from deerflow.config.runtime_paths import project_root, resolve_path


class SkillsConfig(BaseModel):
@ -17,7 +15,7 @@ class SkillsConfig(BaseModel):
    )
    path: str | None = Field(
        default=None,
        description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
        description="Path to skills directory. If not specified, defaults to skills under the caller project root.",
    )
    container_path: str = Field(
        default="/mnt/skills",
@ -32,15 +30,11 @@ class SkillsConfig(BaseModel):
        Path to the skills directory
        """
        if self.path:
            # Use configured path (can be absolute or relative)
            path = Path(self.path)
            if not path.is_absolute():
                # If relative, resolve from the repo root for deterministic behavior.
                path = _default_repo_root() / path
            return path.resolve()
        else:
            # Default: <repo_root>/skills
            return _default_repo_root() / "skills"
            # Use configured path (can be absolute or relative to project root)
            return resolve_path(self.path)
        if env_path := os.getenv("DEER_FLOW_SKILLS_PATH"):
            return resolve_path(env_path)
        return project_root() / "skills"

    def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
        """
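The effective skills-path priority is therefore config `path` > DEER_FLOW_SKILLS_PATH > `{project_root}/skills`. A rough sketch, assuming the model's other fields keep their defaults (field values illustrative):

from deerflow.config.skills_config import SkillsConfig

# Relative config path: resolved against the project root.
assert SkillsConfig(path="custom/skills").get_skills_path().name == "skills"

# No config path: DEER_FLOW_SKILLS_PATH wins if set, else {project_root}/skills.
default_path = SkillsConfig().get_skills_path()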
@ -27,6 +27,34 @@ from deerflow.models.credential_loader import CodexCliCredential, load_codex_cli
logger = logging.getLogger(__name__)

CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"


def _build_usage_metadata(oai_usage: dict) -> dict:
    """Convert Codex/Responses API usage dict to LangChain usage_metadata format.

    Maps OpenAI Responses API token usage fields to the dict structure that
    LangChain AIMessage.usage_metadata expects. This avoids depending on
    langchain_openai private helpers like ``_create_usage_metadata_responses``.
    """
    input_tokens = oai_usage.get("input_tokens", 0)
    output_tokens = oai_usage.get("output_tokens", 0)
    total_tokens = oai_usage.get("total_tokens", input_tokens + output_tokens)
    metadata: dict = {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
    }
    input_details = oai_usage.get("input_tokens_details") or {}
    output_details = oai_usage.get("output_tokens_details") or {}
    cache_read = input_details.get("cached_tokens")
    if cache_read is not None:
        metadata["input_token_details"] = {"cache_read": cache_read}
    reasoning = output_details.get("reasoning_tokens")
    if reasoning is not None:
        metadata["output_token_details"] = {"reasoning": reasoning}
    return metadata


MAX_RETRIES = 3


@ -346,6 +374,7 @@ class CodexChatModel(BaseChatModel):
        )

        usage = response.get("usage", {})
        usage_metadata = _build_usage_metadata(usage) if usage else None
        additional_kwargs = {}
        if reasoning_content:
            additional_kwargs["reasoning_content"] = reasoning_content
@ -355,6 +384,7 @@ class CodexChatModel(BaseChatModel):
            tool_calls=tool_calls if tool_calls else [],
            invalid_tool_calls=invalid_tool_calls,
            additional_kwargs=additional_kwargs,
            usage_metadata=usage_metadata,
            response_metadata={
                "model": response.get("model", self.model),
                "usage": usage,
@ -7,13 +7,13 @@ router for thread records.

from __future__ import annotations

import time
from typing import Any

from langgraph.store.base import BaseStore

from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
from deerflow.utils.time import coerce_iso, now_iso

THREADS_NS: tuple[str, ...] = ("threads",)

@ -48,7 +48,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
        metadata: dict | None = None,
    ) -> dict:
        resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
        now = time.time()
        now = now_iso()
        record: dict[str, Any] = {
            "thread_id": thread_id,
            "assistant_id": assistant_id,
@ -106,7 +106,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
        if record is None:
            return
        record["display_name"] = display_name
        record["updated_at"] = time.time()
        record["updated_at"] = now_iso()
        await self._store.aput(THREADS_NS, thread_id, record)

    async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@ -114,7 +114,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
        if record is None:
            return
        record["status"] = status
        record["updated_at"] = time.time()
        record["updated_at"] = now_iso()
        await self._store.aput(THREADS_NS, thread_id, record)

    async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@ -124,7 +124,7 @@ class MemoryThreadMetaStore(ThreadMetaStore):
        merged = dict(record.get("metadata") or {})
        merged.update(metadata)
        record["metadata"] = merged
        record["updated_at"] = time.time()
        record["updated_at"] = now_iso()
        await self._store.aput(THREADS_NS, thread_id, record)

    async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
@ -144,6 +144,8 @@ class MemoryThreadMetaStore(ThreadMetaStore):
            "display_name": val.get("display_name"),
            "status": val.get("status", "idle"),
            "metadata": val.get("metadata", {}),
            "created_at": str(val.get("created_at", "")),
            "updated_at": str(val.get("updated_at", "")),
            # ``coerce_iso`` heals legacy unix-second values written by
            # earlier Gateway versions that called ``str(time.time())``.
            "created_at": coerce_iso(val.get("created_at", "")),
            "updated_at": coerce_iso(val.get("updated_at", "")),
        }
@ -6,9 +6,10 @@ import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import UTC, datetime
from typing import TYPE_CHECKING

from deerflow.utils.time import now_iso as _now_iso

from .schemas import DisconnectMode, RunStatus

if TYPE_CHECKING:
@ -17,10 +18,6 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)


def _now_iso() -> str:
    return datetime.now(UTC).isoformat()


@dataclass
class RunRecord:
    """Mutable record for a single run."""
@ -21,7 +21,9 @@ import inspect
import logging
from dataclasses import dataclass, field
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast

from langgraph.checkpoint.base import empty_checkpoint

if TYPE_CHECKING:
    from langchain_core.messages import HumanMessage
@ -39,12 +41,19 @@ logger = logging.getLogger(__name__)
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}


def _build_runtime_context(thread_id: str, run_id: str, caller_context: Any | None) -> dict[str, Any]:
def _build_runtime_context(
    thread_id: str,
    run_id: str,
    caller_context: Any | None,
    app_config: AppConfig | None = None,
) -> dict[str, Any]:
    """Build the dict that becomes ``ToolRuntime.context`` for the run.

    Always includes ``thread_id`` and ``run_id``. Additional keys from the caller's
    ``config['context']`` (e.g. ``agent_name`` for the bootstrap flow — issue #2677)
    are merged in but never override ``thread_id``/``run_id``.
    are merged in but never override ``thread_id``/``run_id``. The resolved
    ``AppConfig`` is added by the worker so tools can consume it without ambient
    global lookups.

    langgraph 1.1+ surfaces this as ``runtime.context`` via the parent runtime stored
    under ``config['configurable']['__pregel_runtime']`` — see
@ -54,6 +63,8 @@ def _build_runtime_context(thread_id: str, run_id: str, caller_context: Any | No
    if isinstance(caller_context, dict):
        for key, value in caller_context.items():
            runtime_ctx.setdefault(key, value)
    if app_config is not None:
        runtime_ctx["app_config"] = app_config
    return runtime_ctx


@ -74,6 +85,18 @@ class RunContext:
    app_config: AppConfig | None = field(default=None)


def _install_runtime_context(config: dict, runtime_context: dict[str, Any]) -> None:
    existing_context = config.get("context")
    if isinstance(existing_context, dict):
        existing_context.setdefault("thread_id", runtime_context["thread_id"])
        existing_context.setdefault("run_id", runtime_context["run_id"])
        if "app_config" in runtime_context:
            existing_context["app_config"] = runtime_context["app_config"]
        return

    config["context"] = dict(runtime_context)


def _compute_agent_factory_supports_app_config(agent_factory: Any) -> bool:
    try:
        return "app_config" in inspect.signature(agent_factory).parameters
@ -191,11 +214,9 @@ async def run_agent(
    # access thread-level data. langgraph-cli does this automatically; we must do it
    # manually here because we drive the graph through ``agent.astream(config=...)``
    # without passing the official ``context=`` parameter.
    runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"))
    if "context" in config and isinstance(config["context"], dict):
        config["context"].setdefault("thread_id", thread_id)
        config["context"].setdefault("run_id", run_id)
    runtime = Runtime(context=runtime_ctx, store=store)
    runtime_ctx = _build_runtime_context(thread_id, run_id, config.get("context"), ctx.app_config)
    _install_runtime_context(config, runtime_ctx)
    runtime = Runtime(context=cast(Any, runtime_ctx), store=store)
    config.setdefault("configurable", {})["__pregel_runtime"] = runtime

    # Inject RunJournal as a LangChain callback handler.
@ -423,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint(
    if checkpoint_to_restore.get("id") is None:
        logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
        return
    restore_marker = _new_checkpoint_marker()
    checkpoint_to_restore = {
        **checkpoint_to_restore,
        "id": restore_marker["id"],
        "ts": restore_marker["ts"],
    }
    metadata = pre_run_snapshot.get("metadata", {})
    metadata_to_restore = metadata if isinstance(metadata, dict) else {}
    raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
@ -474,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint(
    )


def _new_checkpoint_marker() -> dict[str, str]:
    marker = empty_checkpoint()
    return {"id": marker["id"], "ts": marker["ts"]}


def _lg_mode_to_sse_event(mode: str) -> str:
    """Map LangGraph internal stream_mode name to SSE event name.
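The merge rule in _install_runtime_context keeps caller-supplied keys while filling in the identity keys and always refreshing app_config. A rough sketch with the function in scope (the object is a stand-in for the resolved AppConfig):

app_cfg = object()  # stands in for the resolved AppConfig
config = {"context": {"agent_name": "lead"}}
runtime_ctx = {"thread_id": "t-1", "run_id": "r-1", "app_config": app_cfg}
_install_runtime_context(config, runtime_ctx)

# Caller keys survive; thread_id/run_id are filled via setdefault.
assert config["context"] == {"agent_name": "lead", "thread_id": "t-1", "run_id": "r-1", "app_config": app_cfg}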
@ -12,7 +12,7 @@ from collections.abc import Iterable
from datetime import UTC, datetime
from pathlib import Path

from deerflow.config.skills_config import _default_repo_root
from deerflow.config.runtime_paths import resolve_path
from deerflow.skills.storage.skill_storage import SKILL_MD_FILE, SkillStorage
from deerflow.skills.types import SkillCategory

@ -44,10 +44,7 @@ class LocalSkillStorage(SkillStorage):
            config = app_config or get_app_config()
            self._host_root: Path = config.skills.get_skills_path()
        else:
            path = Path(host_path)
            if not path.is_absolute():
                path = _default_repo_root() / path
            self._host_root = path.resolve()
            self._host_root = resolve_path(host_path)

    # ------------------------------------------------------------------
    # Abstract operation implementations
@ -1,6 +1,10 @@
"""Subagent configuration definitions."""

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from deerflow.config.app_config import AppConfig


@dataclass
@ -29,3 +33,24 @@ class SubagentConfig:
    model: str = "inherit"
    max_turns: int = 50
    timeout_seconds: int = 900


def _default_model_name(app_config: "AppConfig") -> str:
    if not app_config.models:
        raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
    return app_config.models[0].name


def resolve_subagent_model_name(config: SubagentConfig, parent_model: str | None, *, app_config: "AppConfig | None" = None) -> str:
    """Resolve the effective model name a subagent should use."""
    if config.model != "inherit":
        return config.model

    if parent_model is not None:
        return parent_model

    if app_config is None:
        from deerflow.config import get_app_config

        app_config = get_app_config()
    return _default_model_name(app_config)
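Resolution order is explicit model > inherited parent model > first configured model. Sketched against a loaded config (the subagent name is hypothetical; other dataclass fields are left untouched via replace):

from dataclasses import replace

from deerflow.subagents import get_subagent_config
from deerflow.subagents.config import resolve_subagent_model_name

config = get_subagent_config("researcher")  # hypothetical subagent name
assert resolve_subagent_model_name(replace(config, model="my-model"), parent_model=None) == "my-model"
assert resolve_subagent_model_name(replace(config, model="inherit"), parent_model="parent-model") == "parent-model"
# With model="inherit" and no parent, falls back to app_config.models[0].name.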
@ -20,9 +20,10 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig

from deerflow.agents.thread_state import SandboxState, ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
from deerflow.subagents.config import SubagentConfig
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name

logger = logging.getLogger(__name__)

@ -167,6 +168,8 @@ def _get_isolated_subagent_loop() -> asyncio.AbstractEventLoop:
    _isolated_subagent_loop_thread = thread
    _isolated_subagent_loop_started = started_event

    if _isolated_subagent_loop is None:
        raise RuntimeError("Isolated subagent event loop is not initialized")
    return _isolated_subagent_loop


@ -213,21 +216,6 @@ def _filter_tools(
    return filtered


def _get_model_name(config: SubagentConfig, parent_model: str | None) -> str | None:
    """Resolve the model name for a subagent.

    Args:
        config: Subagent configuration.
        parent_model: The parent agent's model name.

    Returns:
        Model name to use, or None to use default.
    """
    if config.model == "inherit":
        return parent_model
    return config.model


class SubagentExecutor:
    """Executor for running subagents."""

@ -247,9 +235,9 @@ class SubagentExecutor:
        Args:
            config: Subagent configuration.
            tools: List of all available tools (will be filtered).
            app_config: Resolved AppConfig; threaded into middleware factories
                at agent-build time. When None, ``_create_agent`` falls back to
                ``get_app_config()`` (matches the lead-agent factory's pattern).
            app_config: Resolved AppConfig. When None, ``_create_agent`` falls
                back to ``get_app_config()`` (matches the lead-agent factory's
                pattern).
            parent_model: The parent agent's model name for inheritance.
            sandbox_state: Sandbox state from parent agent.
            thread_data: Thread data from parent agent.
@ -259,6 +247,13 @@ class SubagentExecutor:
        self.config = config
        self.app_config = app_config
        self.parent_model = parent_model
        # Resolve eagerly only when it does not require loading config.yaml; otherwise defer
        # to _create_agent (which already loads app_config) so unit tests can construct
        # executors without a config file present.
        if config.model != "inherit" or parent_model is not None or app_config is not None:
            self.model_name: str | None = resolve_subagent_model_name(config, parent_model, app_config=app_config)
        else:
            self.model_name = None
        self.sandbox_state = sandbox_state
        self.thread_data = thread_data
        self.thread_id = thread_id
@ -276,17 +271,15 @@ class SubagentExecutor:

    def _create_agent(self):
        """Create the agent instance."""
        # Mirror lead-agent factory pattern: prefer explicit app_config,
        # fall back to ambient lookup at agent-build time.
        from deerflow.config import get_app_config

        resolved_app_config = self.app_config or get_app_config()
        model_name = _get_model_name(self.config, self.parent_model)
        model = create_chat_model(name=model_name, thinking_enabled=False, app_config=resolved_app_config)
        app_config = self.app_config or get_app_config()
        if self.model_name is None:
            self.model_name = resolve_subagent_model_name(self.config, self.parent_model, app_config=app_config)
        model = create_chat_model(name=self.model_name, thinking_enabled=False, app_config=app_config)

        from deerflow.agents.middlewares.tool_error_handling_middleware import build_subagent_runtime_middlewares

        middlewares = build_subagent_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
        # Reuse shared middleware composition with lead agent.
        middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=self.model_name, lazy_init=True)

        return create_agent(
            model=model,
@ -317,8 +310,10 @@ class SubagentExecutor:
        try:
            from deerflow.skills.storage import get_or_new_skill_storage

            storage_kwargs = {"app_config": self.app_config} if self.app_config is not None else {}
            storage = await asyncio.to_thread(get_or_new_skill_storage, **storage_kwargs)
            # Use asyncio.to_thread to avoid blocking the event loop (LangGraph ASGI requirement)
            all_skills = await asyncio.to_thread(get_or_new_skill_storage().load_skills, enabled_only=True)
            all_skills = await asyncio.to_thread(storage.load_skills, enabled_only=True)
            logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} loaded {len(all_skills)} enabled skills from disk")
        except Exception:
            logger.warning(f"[trace={self.trace_id}] Failed to load skills for subagent {self.config.name}", exc_info=True)
@ -404,6 +399,10 @@ class SubagentExecutor:
            status=SubagentStatus.RUNNING,
            started_at=datetime.now(),
        )
        ai_messages = result.ai_messages
        if ai_messages is None:
            ai_messages = []
        result.ai_messages = ai_messages

        try:
            agent = self._create_agent()
@ -413,10 +412,12 @@ class SubagentExecutor:
            run_config: RunnableConfig = {
                "recursion_limit": self.config.max_turns,
            }
            context = {}
            context: dict[str, Any] = {}
            if self.thread_id:
                run_config["configurable"] = {"thread_id": self.thread_id}
                context["thread_id"] = self.thread_id
            if self.app_config is not None:
                context["app_config"] = self.app_config

            logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} starting async execution with max_turns={self.config.max_turns}")

@ -463,13 +464,13 @@ class SubagentExecutor:
            message_id = message_dict.get("id")
            is_duplicate = False
            if message_id:
                is_duplicate = any(msg.get("id") == message_id for msg in result.ai_messages)
                is_duplicate = any(msg.get("id") == message_id for msg in ai_messages)
            else:
                is_duplicate = message_dict in result.ai_messages
                is_duplicate = message_dict in ai_messages

            if not is_duplicate:
                result.ai_messages.append(message_dict)
                logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(result.ai_messages)}")
                ai_messages.append(message_dict)
                logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")

        logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
@ -2,6 +2,7 @@

import logging
from dataclasses import replace
from typing import Any

from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
@ -10,19 +11,26 @@ from deerflow.subagents.config import SubagentConfig

logger = logging.getLogger(__name__)


def _build_custom_subagent_config(name: str) -> SubagentConfig | None:
def _resolve_subagents_app_config(app_config: Any | None = None):
    if app_config is None:
        from deerflow.config.subagents_config import get_subagents_app_config

        return get_subagents_app_config()
    return getattr(app_config, "subagents", app_config)


def _build_custom_subagent_config(name: str, *, app_config: Any | None = None) -> SubagentConfig | None:
    """Build a SubagentConfig from config.yaml custom_agents section.

    Args:
        name: The name of the custom subagent.
        app_config: Optional AppConfig or SubagentsAppConfig to resolve from.

    Returns:
        SubagentConfig if found in custom_agents, None otherwise.
    """
    from deerflow.config.subagents_config import get_subagents_app_config

    app_config = get_subagents_app_config()
    custom = app_config.custom_agents.get(name)
    subagents_config = _resolve_subagents_app_config(app_config)
    custom = subagents_config.custom_agents.get(name)
    if custom is None:
        return None

@ -39,7 +47,7 @@ def _build_custom_subagent_config(name: str) -> SubagentConfig | None:
    )


def get_subagent_config(name: str) -> SubagentConfig | None:
def get_subagent_config(name: str, *, app_config: Any | None = None) -> SubagentConfig | None:
    """Get a subagent configuration by name, with config.yaml overrides applied.

    Resolution order (mirrors Codex's config layering):
@ -49,6 +57,7 @@ def get_subagent_config(name: str) -> SubagentConfig | None:

    Args:
        name: The name of the subagent.
        app_config: Optional AppConfig or SubagentsAppConfig to resolve overrides from.

    Returns:
        SubagentConfig if found (with any config.yaml overrides applied), None otherwise.
@ -56,7 +65,7 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
    # Step 1: Look up built-in, then fall back to custom_agents
    config = BUILTIN_SUBAGENTS.get(name)
    if config is None:
        config = _build_custom_subagent_config(name)
        config = _build_custom_subagent_config(name, app_config=app_config)
    if config is None:
        return None

@ -65,12 +74,9 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
    # (timeout_seconds, max_turns at the top level) apply to built-in agents
    # but must NOT override custom agents' own values — custom agents define
    # their own defaults in the custom_agents section.
    # Lazy import to avoid circular deps.
    from deerflow.config.subagents_config import get_subagents_app_config

    app_config = get_subagents_app_config()
    subagents_config = _resolve_subagents_app_config(app_config)
    is_builtin = name in BUILTIN_SUBAGENTS
    agent_override = app_config.agents.get(name)
    agent_override = subagents_config.agents.get(name)

    overrides = {}

@ -79,27 +85,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
        if agent_override.timeout_seconds != config.timeout_seconds:
            logger.debug("Subagent '%s': timeout overridden (%ss -> %ss)", name, config.timeout_seconds, agent_override.timeout_seconds)
            overrides["timeout_seconds"] = agent_override.timeout_seconds
    elif is_builtin and app_config.timeout_seconds != config.timeout_seconds:
        logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, app_config.timeout_seconds)
        overrides["timeout_seconds"] = app_config.timeout_seconds
    elif is_builtin and subagents_config.timeout_seconds != config.timeout_seconds:
        logger.debug("Subagent '%s': timeout from global default (%ss -> %ss)", name, config.timeout_seconds, subagents_config.timeout_seconds)
        overrides["timeout_seconds"] = subagents_config.timeout_seconds

    # Max turns: per-agent override > global default (builtins only) > config's own value
    if agent_override is not None and agent_override.max_turns is not None:
        if agent_override.max_turns != config.max_turns:
            logger.debug("Subagent '%s': max_turns overridden (%s -> %s)", name, config.max_turns, agent_override.max_turns)
            overrides["max_turns"] = agent_override.max_turns
    elif is_builtin and app_config.max_turns is not None and app_config.max_turns != config.max_turns:
        logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, app_config.max_turns)
        overrides["max_turns"] = app_config.max_turns
    elif is_builtin and subagents_config.max_turns is not None and subagents_config.max_turns != config.max_turns:
        logger.debug("Subagent '%s': max_turns from global default (%s -> %s)", name, config.max_turns, subagents_config.max_turns)
        overrides["max_turns"] = subagents_config.max_turns

    # Model: per-agent override only (no global default for model)
    effective_model = app_config.get_model_for(name)
    effective_model = subagents_config.get_model_for(name)
    if effective_model is not None and effective_model != config.model:
        logger.debug("Subagent '%s': model overridden (%s -> %s)", name, config.model, effective_model)
        overrides["model"] = effective_model

    # Skills: per-agent override only (no global default for skills)
    effective_skills = app_config.get_skills_for(name)
    effective_skills = subagents_config.get_skills_for(name)
    if effective_skills is not None and effective_skills != config.skills:
        logger.debug("Subagent '%s': skills overridden (%s -> %s)", name, config.skills, effective_skills)
        overrides["skills"] = effective_skills
@ -110,21 +116,21 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
    return config


def list_subagents() -> list[SubagentConfig]:
def list_subagents(*, app_config: Any | None = None) -> list[SubagentConfig]:
    """List all available subagent configurations (with config.yaml overrides applied).

    Returns:
        List of all registered SubagentConfig instances (built-in + custom).
    """
    configs = []
    for name in get_subagent_names():
        config = get_subagent_config(name)
    for name in get_subagent_names(app_config=app_config):
        config = get_subagent_config(name, app_config=app_config)
        if config is not None:
            configs.append(config)
    return configs


def get_subagent_names() -> list[str]:
def get_subagent_names(*, app_config: Any | None = None) -> list[str]:
    """Get all available subagent names (built-in + custom).

    Returns:
@ -133,25 +139,23 @@ def get_subagent_names() -> list[str]:
    names = list(BUILTIN_SUBAGENTS.keys())

    # Merge custom_agents from config.yaml
    from deerflow.config.subagents_config import get_subagents_app_config

    app_config = get_subagents_app_config()
    for custom_name in app_config.custom_agents:
    subagents_config = _resolve_subagents_app_config(app_config)
    for custom_name in subagents_config.custom_agents:
        if custom_name not in names:
            names.append(custom_name)

    return names


def get_available_subagent_names() -> list[str]:
def get_available_subagent_names(*, app_config: Any | None = None) -> list[str]:
    """Get subagent names that should be exposed to the active runtime.

    Returns:
        List of subagent names visible to the current sandbox configuration.
    """
    names = get_subagent_names()
    names = get_subagent_names(app_config=app_config)
    try:
        host_bash_allowed = is_host_bash_allowed()
        host_bash_allowed = is_host_bash_allowed(app_config) if hasattr(app_config, "sandbox") else is_host_bash_allowed()
    except Exception:
        logger.debug("Could not determine host bash availability; exposing all subagents")
        return names
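_resolve_subagents_app_config duck-types its argument, so callers can pass either object. Roughly, with the helper in scope and assuming SubagentsAppConfig itself has no subagents attribute:

from deerflow.config.app_config import AppConfig

app_config = AppConfig.from_file()  # requires a config.yaml per the resolution rules above
subagents_cfg = app_config.subagents

assert _resolve_subagents_app_config(app_config) is subagents_cfg     # unwrapped via .subagents
assert _resolve_subagents_app_config(subagents_cfg) is subagents_cfg  # passed through unchanged
# _resolve_subagents_app_config(None) falls back to get_subagents_app_config().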
@ -4,20 +4,39 @@ import asyncio
import logging
import uuid
from dataclasses import replace
from typing import Annotated
from typing import TYPE_CHECKING, Annotated, Any, cast

from langchain.tools import InjectedToolCallId, ToolRuntime, tool
from langgraph.config import get_stream_writer
from langgraph.typing import ContextT

from deerflow.agents.thread_state import ThreadState
from deerflow.config import get_app_config
from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed
from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config
from deerflow.subagents.executor import SubagentStatus, cleanup_background_task, get_background_task_result, request_cancel_background_task
from deerflow.subagents.config import resolve_subagent_model_name
from deerflow.subagents.executor import (
    SubagentStatus,
    cleanup_background_task,
    get_background_task_result,
    request_cancel_background_task,
)

if TYPE_CHECKING:
    from deerflow.config.app_config import AppConfig

logger = logging.getLogger(__name__)


def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
    context = getattr(runtime, "context", None)
    if isinstance(context, dict):
        app_config = context.get("app_config")
        if app_config is not None:
            return cast("AppConfig", app_config)
    return None


def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) -> list[str] | None:
    """Return the effective subagent skill allowlist under the parent policy."""
    if parent is None:
@ -74,15 +93,18 @@ async def task_tool(
        subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
        max_turns: Optional maximum number of agent turns. Defaults to subagent's configured max.
    """
    available_subagent_names = get_available_subagent_names()
    runtime_app_config = _get_runtime_app_config(runtime)
    available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()

    # Get subagent configuration
    config = get_subagent_config(subagent_type)
    config = get_subagent_config(subagent_type, app_config=runtime_app_config) if runtime_app_config is not None else get_subagent_config(subagent_type)
    if config is None:
        available = ", ".join(available_subagent_names)
        return f"Error: Unknown subagent type '{subagent_type}'. Available: {available}"
    if subagent_type == "bash" and not is_host_bash_allowed():
        return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"
    if subagent_type == "bash":
        host_bash_allowed = is_host_bash_allowed(runtime_app_config) if runtime_app_config is not None else is_host_bash_allowed()
        if not host_bash_allowed:
            return f"Error: {LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE}"

    # Build config overrides
    overrides: dict = {}
@ -129,20 +151,34 @@ async def task_tool(

    # Inherit parent agent's tool_groups so subagents respect the same restrictions
    parent_tool_groups = metadata.get("tool_groups")
    resolved_app_config = runtime_app_config
    if config.model == "inherit" and parent_model is None and resolved_app_config is None:
        resolved_app_config = get_app_config()
    effective_model = resolve_subagent_model_name(config, parent_model, app_config=resolved_app_config)

    # Subagents should not have subagent tools enabled (prevent recursive nesting)
    tools = get_available_tools(model_name=parent_model, groups=parent_tool_groups, subagent_enabled=False)
    available_tools_kwargs = {
        "model_name": effective_model,
        "groups": parent_tool_groups,
        "subagent_enabled": False,
    }
    if resolved_app_config is not None:
        available_tools_kwargs["app_config"] = resolved_app_config
    tools = get_available_tools(**available_tools_kwargs)

    # Create executor
    executor = SubagentExecutor(
        config=config,
        tools=tools,
        parent_model=parent_model,
        sandbox_state=sandbox_state,
        thread_data=thread_data,
        thread_id=thread_id,
        trace_id=trace_id,
    )
    executor_kwargs = {
        "config": config,
        "tools": tools,
        "parent_model": parent_model,
        "sandbox_state": sandbox_state,
        "thread_data": thread_data,
        "thread_id": thread_id,
        "trace_id": trace_id,
    }
    if resolved_app_config is not None:
        executor_kwargs["app_config"] = resolved_app_config
    executor = SubagentExecutor(**executor_kwargs)

    # Start background execution (always async to prevent blocking)
    # Use tool_call_id as task_id for better traceability
@ -177,11 +213,12 @@ async def task_tool(
        last_status = result.status

        # Check for new AI messages and send task_running events
        current_message_count = len(result.ai_messages)
        ai_messages = result.ai_messages or []
        current_message_count = len(ai_messages)
        if current_message_count > last_message_count:
            # Send task_running event for each new message
            for i in range(last_message_count, current_message_count):
                message = result.ai_messages[i]
                message = ai_messages[i]
                writer(
                    {
                        "type": "task_running",
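On the consuming side, _get_runtime_app_config just reads the key the worker installed into runtime.context. A sketch with the helper in scope and a stand-in runtime object:

app_cfg = object()  # stands in for the resolved AppConfig

class FakeRuntime:  # mimics only the attribute _get_runtime_app_config reads
    context = {"thread_id": "t-1", "run_id": "r-1", "app_config": app_cfg}

assert _get_runtime_app_config(FakeRuntime()) is app_cfg
assert _get_runtime_app_config(object()) is None  # no context attr -> caller falls back to ambient config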
@ -141,10 +141,14 @@ def get_available_tools(
    # Add invoke_acp_agent tool if any ACP agents are configured
    acp_tools: list[BaseTool] = []
    try:
        from deerflow.config.acp_config import get_acp_agents
        from deerflow.tools.builtins.invoke_acp_agent_tool import build_invoke_acp_agent_tool

        acp_agents = get_acp_agents()
        if app_config is None:
            from deerflow.config.acp_config import get_acp_agents

            acp_agents = get_acp_agents()
        else:
            acp_agents = getattr(config, "acp_agents", {}) or {}
        if acp_agents:
            acp_tools.append(build_invoke_acp_agent_tool(acp_agents))
            logger.info(f"Including invoke_acp_agent tool ({len(acp_agents)} agent(s): {list(acp_agents.keys())})")
@ -4,8 +4,10 @@ Pure business logic — no FastAPI/HTTP dependencies.
Both Gateway and Client delegate to these functions.
"""

import errno
import os
import re
import stat
from pathlib import Path
from urllib.parse import quote

@ -17,6 +19,10 @@ class PathTraversalError(ValueError):
    """Raised when a path escapes its allowed base directory."""


class UnsafeUploadPathError(ValueError):
    """Raised when an upload destination is not a safe regular file path."""


# thread_id must be alphanumeric, hyphens, underscores, or dots only.
_SAFE_THREAD_ID = re.compile(r"^[a-zA-Z0-9._-]+$")

@ -109,6 +115,64 @@ def validate_path_traversal(path: Path, base: Path) -> None:
        raise PathTraversalError("Path traversal detected") from None


def open_upload_file_no_symlink(base_dir: Path, filename: str) -> tuple[Path, object]:
    """Open an upload destination for safe streaming writes.

    Upload directories may be mounted into local sandboxes. A sandbox process can
    therefore leave a symlink at a future upload filename. Normal ``Path.write_bytes``
    follows that link and can overwrite files outside the uploads directory with
    gateway privileges. This helper rejects symlink destinations and uses
    ``O_NOFOLLOW`` where available so the final path component cannot be raced into
    a symlink between validation and open.
    """
    safe_name = normalize_filename(filename)
    dest = base_dir / safe_name

    try:
        st = os.lstat(dest)
    except FileNotFoundError:
        st = None

    if st is not None and not stat.S_ISREG(st.st_mode):
        raise UnsafeUploadPathError(f"Upload destination is not a regular file: {safe_name}")

    validate_path_traversal(dest, base_dir)

    if not hasattr(os, "O_NOFOLLOW"):
        raise UnsafeUploadPathError("Upload writes require O_NOFOLLOW support")

    flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW
    if hasattr(os, "O_NONBLOCK"):
        flags |= os.O_NONBLOCK

    try:
        fd = os.open(dest, flags, 0o600)
    except OSError as exc:
        if exc.errno in {errno.ELOOP, errno.EISDIR, errno.ENOTDIR, errno.ENXIO, errno.EAGAIN}:
            raise UnsafeUploadPathError(f"Unsafe upload destination: {safe_name}") from exc
        raise

    try:
        opened_stat = os.fstat(fd)
        if not stat.S_ISREG(opened_stat.st_mode) or opened_stat.st_nlink != 1:
            raise UnsafeUploadPathError(f"Upload destination is not an exclusive regular file: {safe_name}")
        os.ftruncate(fd, 0)
        fh = os.fdopen(fd, "wb")
        fd = -1
    finally:
        if fd >= 0:
            os.close(fd)
    return dest, fh


def write_upload_file_no_symlink(base_dir: Path, filename: str, data: bytes) -> Path:
    """Write upload bytes without following a pre-existing destination symlink."""
    dest, fh = open_upload_file_no_symlink(base_dir, filename)
    with fh:
        fh.write(data)
    return dest


def list_files_in_dir(directory: Path) -> dict:
    """List files (not directories) in *directory*.
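A quick sketch of the intended behaviour on a POSIX system (paths illustrative, and normalize_filename assumed to pass these names through unchanged):

import tempfile
from pathlib import Path

from deerflow.uploads.manager import UnsafeUploadPathError, write_upload_file_no_symlink

base = Path(tempfile.mkdtemp())
dest = write_upload_file_no_symlink(base, "report.txt", b"hello")  # regular file: written
assert dest.read_bytes() == b"hello"

(base / "evil.txt").symlink_to("/etc/passwd")
try:
    write_upload_file_no_symlink(base, "evil.txt", b"payload")
except UnsafeUploadPathError:
    pass  # symlink destination rejected instead of followed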
75
backend/packages/harness/deerflow/utils/time.py
Normal file
@ -0,0 +1,75 @@
"""ISO 8601 timestamp helpers for the Gateway and embedded runtime.

DeerFlow stores and serializes thread/run timestamps as ISO 8601 UTC
strings to match the LangGraph Platform schema (see
``langgraph_sdk.schema.Thread``, where ``created_at`` / ``updated_at``
are ``datetime`` and JSON-encode to ISO 8601). All timestamp generation
should funnel through :func:`now_iso` so the wire format stays
consistent across endpoints, the embedded ``RunManager``, and the
checkpoint metadata written by the Gateway.

:func:`coerce_iso` provides a forward-compatible read path for legacy
records that historically stored ``str(time.time())`` floats.
"""

from __future__ import annotations

import re
from datetime import UTC, datetime

__all__ = ["coerce_iso", "now_iso"]

_UNIX_TIMESTAMP_PATTERN = re.compile(r"^\d{10}(?:\.\d+)?$")
"""Matches the unix-timestamp string shape historically written by
``str(time.time())`` (10-digit seconds with optional fractional part).
The 10-digit anchor avoids accidentally rewriting ISO years like
``"2026"`` and stays valid until the year 2286.
"""


def now_iso() -> str:
    """Return the current UTC time as an ISO 8601 string.

    Example: ``"2026-04-27T03:19:46.511479+00:00"``.
    """
    return datetime.now(UTC).isoformat()


def coerce_iso(value: object) -> str:
    """Best-effort coerce a stored timestamp to an ISO 8601 string.

    Translates legacy unix-timestamp floats / strings written by older
    DeerFlow versions into ISO without a one-shot migration. ISO strings
    pass through unchanged; ``datetime`` instances are normalised to UTC
    (tz-naive values are assumed to be UTC) and emitted via
    ``isoformat()`` so the wire format always uses the ``T`` separator;
    empty values become ``""``; unrecognised values are stringified as a
    last resort.
    """
    if value is None or value == "":
        return ""
    if isinstance(value, bool):
        # ``bool`` is a subclass of ``int`` — treat as garbage, not 0/1.
        return str(value)
    if isinstance(value, datetime):
        # ``datetime`` must be handled before the ``int``/``float`` check;
        # str(datetime) would produce ``"YYYY-MM-DD HH:MM:SS+00:00"``
        # (space separator), which breaks strict ISO 8601 consumers.
        if value.tzinfo is None:
            value = value.replace(tzinfo=UTC)
        else:
            value = value.astimezone(UTC)
        return value.isoformat()
    if isinstance(value, (int, float)):
        try:
            return datetime.fromtimestamp(float(value), UTC).isoformat()
        except (ValueError, OverflowError, OSError):
            return str(value)
    if isinstance(value, str):
        if _UNIX_TIMESTAMP_PATTERN.match(value):
            try:
                return datetime.fromtimestamp(float(value), UTC).isoformat()
            except (ValueError, OverflowError, OSError):
                return value
        return value
    return str(value)
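Behaviour sketch (each call returns an ISO 8601 UTC string per the rules above):

from datetime import datetime

from deerflow.utils.time import coerce_iso

coerce_iso("1714186786.5")                 # legacy str(time.time()) -> converted to ISO UTC
coerce_iso("2026-04-27T03:19:46+00:00")    # ISO input passes through unchanged
coerce_iso(datetime(2026, 4, 27, 3, 19))   # tz-naive -> assumed UTC: "2026-04-27T03:19:00+00:00"
coerce_iso("")                             # -> ""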
@ -47,4 +47,3 @@ members = ["packages/harness"]

[tool.uv.sources]
deerflow-harness = { workspace = true }
@ -133,6 +133,58 @@ class TestListDirSerialization:
        assert lock_was_held == [True], "list_dir must hold the lock during exec_command"


class TestNoChangeTimeout:
    """Verify that no_change_timeout is forwarded to every exec_command call."""

    def test_execute_command_passes_no_change_timeout(self, sandbox):
        """execute_command should pass no_change_timeout to exec_command."""
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            return SimpleNamespace(data=SimpleNamespace(output="ok"))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("echo hello")

        assert len(calls) == 1
        assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT

    def test_retry_passes_no_change_timeout(self, sandbox):
        """The ErrorObservation retry path should also pass no_change_timeout."""
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            if len(calls) == 1:
                return SimpleNamespace(data=SimpleNamespace(output="'ErrorObservation' object has no attribute 'exit_code'"))
            return SimpleNamespace(data=SimpleNamespace(output="ok"))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.execute_command("echo hello")

        assert len(calls) == 2
        assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT
        assert calls[1].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT

    def test_list_dir_passes_no_change_timeout(self, sandbox):
        """list_dir should pass no_change_timeout to exec_command."""
        calls = []

        def mock_exec(command, **kwargs):
            calls.append(kwargs)
            return SimpleNamespace(data=SimpleNamespace(output="/a\n/b"))

        sandbox._client.shell.exec_command = mock_exec

        sandbox.list_dir("/test")

        assert len(calls) == 1
        assert calls[0].get("no_change_timeout") == sandbox._DEFAULT_NO_CHANGE_TIMEOUT


class TestConcurrentFileWrites:
    """Verify file write paths do not lose concurrent updates."""
@ -3,11 +3,12 @@
from __future__ import annotations

import asyncio
import os
from pathlib import Path
from unittest.mock import MagicMock, patch

from app.channels.base import Channel
from app.channels.message_bus import MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.message_bus import InboundMessage, MessageBus, OutboundMessage, ResolvedAttachment


def _run(coro):
@ -248,6 +249,109 @@ class TestResolveAttachments:
        assert result[0].filename == "data.csv"


# ---------------------------------------------------------------------------
# Inbound file ingestion tests
# ---------------------------------------------------------------------------


class TestInboundFileIngestion:
    def test_rejects_preexisting_symlink_destination(self, tmp_path):
        from app.channels import manager

        uploads_dir = tmp_path / "uploads"
        uploads_dir.mkdir()
        outside_file = tmp_path / "outside-created.txt"
        (uploads_dir / "victim.txt").symlink_to(outside_file)

        msg = InboundMessage(
            channel_name="test-channel",
            chat_id="chat-1",
            user_id="user-1",
            text="see attachment",
            files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
        )

        async def fake_reader(file_info, client):
            return b"attacker data"

        with (
            patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
            patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
        ):
            result = _run(manager._ingest_inbound_files("thread-1", msg))

        assert result == []
        assert not outside_file.exists()
        assert (uploads_dir / "victim.txt").is_symlink()

    def test_rejects_dangling_symlink_destination(self, tmp_path):
        from app.channels import manager

        uploads_dir = tmp_path / "uploads"
        uploads_dir.mkdir()
        missing_target = tmp_path / "missing-created.txt"
        (uploads_dir / "victim.txt").symlink_to(missing_target)

        msg = InboundMessage(
            channel_name="test-channel",
            chat_id="chat-1",
            user_id="user-1",
            text="see attachment",
            files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
        )

        async def fake_reader(file_info, client):
            return b"attacker data"

        with (
            patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
            patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
        ):
            result = _run(manager._ingest_inbound_files("thread-1", msg))

        assert result == []
        assert not missing_target.exists()
        assert (uploads_dir / "victim.txt").is_symlink()

    def test_hardlinked_existing_file_is_not_overwritten(self, tmp_path):
        from app.channels import manager

        uploads_dir = tmp_path / "uploads"
        uploads_dir.mkdir()
        outside_file = tmp_path / "outside-created.txt"
        outside_file.write_text("protected", encoding="utf-8")
        os.link(outside_file, uploads_dir / "victim.txt")

        msg = InboundMessage(
            channel_name="test-channel",
            chat_id="chat-1",
            user_id="user-1",
            text="see attachment",
            files=[{"filename": "victim.txt", "url": "https://example.invalid/victim.txt"}],
        )

        async def fake_reader(file_info, client):
            return b"new attachment data"

        with (
            patch("deerflow.uploads.manager.ensure_uploads_dir", return_value=uploads_dir),
            patch.dict(manager.INBOUND_FILE_READERS, {"test-channel": fake_reader}, clear=False),
        ):
            result = _run(manager._ingest_inbound_files("thread-1", msg))

        assert result == [
            {
                "filename": "victim_1.txt",
                "size": len(b"new attachment data"),
                "path": "/mnt/user-data/uploads/victim_1.txt",
                "is_image": False,
            }
        ]
        assert outside_file.read_text(encoding="utf-8") == "protected"
        assert (uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"
        assert (uploads_dir / "victim_1.txt").read_bytes() == b"new attachment data"


# ---------------------------------------------------------------------------
# Channel base class _on_outbound with attachments
# ---------------------------------------------------------------------------
@ -17,6 +17,7 @@ import json
import os
import uuid
import zipfile
from pathlib import Path

import pytest
from dotenv import load_dotenv

@ -94,12 +95,18 @@ def e2e_env(tmp_path, monkeypatch):
    """Isolated filesystem environment for E2E tests.

    - DEER_FLOW_HOME → tmp_path (all thread data lands in a temp dir)
    - DEER_FLOW_PROJECT_ROOT → repository root (shared skills/config assets
      still resolve correctly when tests run from backend/)
    - Singletons reset so they pick up the new env
    - Title/memory/summarization disabled to avoid extra LLM calls
    - AppConfig built programmatically (avoids config.yaml param-name issues)
    """
    # 1. Filesystem isolation
    monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
    monkeypatch.setenv(
        "DEER_FLOW_PROJECT_ROOT",
        str(Path(__file__).resolve().parents[2]),
    )
    monkeypatch.setattr("deerflow.config.paths._paths", None)
    monkeypatch.setattr("deerflow.sandbox.sandbox_provider._default_sandbox_provider", None)
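Editor's note: the fixture above relies on a common pattern — override path environment variables, then null out cached module singletons so the next accessor call re-resolves from the new environment. A self-contained sketch of that pattern follows; the module and attribute names here are placeholders, not deer-flow's real ones.

```python
# Sketch of "env override + singleton reset" isolation; illustrative names.
import os
import sys

_cached_base_dir = None  # stands in for a module-level singleton like _paths

def get_base_dir() -> str:
    global _cached_base_dir
    if _cached_base_dir is None:
        _cached_base_dir = os.environ.get("DEER_FLOW_HOME", ".deer-flow")
    return _cached_base_dir

def test_base_dir_is_isolated(tmp_path, monkeypatch):
    monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
    # Without this reset, a value cached by an earlier test would leak in.
    monkeypatch.setattr(sys.modules[__name__], "_cached_base_dir", None)
    assert get_base_dir() == str(tmp_path)
```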
@ -82,6 +82,36 @@ def test_parse_response_text_content():
    assert result.generations[0].message.content == "Hello world"


def test_parse_response_populates_usage_metadata():
    model = _make_model()
    response = {
        "output": [
            {
                "type": "message",
                "content": [{"type": "output_text", "text": "Hello world"}],
            }
        ],
        "usage": {
            "input_tokens": 10,
            "output_tokens": 5,
            "total_tokens": 15,
            "input_tokens_details": {"cached_tokens": 3},
            "output_tokens_details": {"reasoning_tokens": 2},
        },
        "model": "gpt-5.4",
    }

    result = model._parse_response(response)

    meta = result.generations[0].message.usage_metadata
    assert meta is not None
    assert meta["input_tokens"] == 10
    assert meta["output_tokens"] == 5
    assert meta["total_tokens"] == 15
    assert meta["input_token_details"]["cache_read"] == 3
    assert meta["output_token_details"]["reasoning"] == 2


def test_parse_response_reasoning_content():
    model = _make_model()
    response = {
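Editor's note: the new usage-metadata test fixes a field mapping — `input_tokens_details.cached_tokens` surfaces as `input_token_details.cache_read`, and `output_tokens_details.reasoning_tokens` as `output_token_details.reasoning`. A hedged sketch of that normalization is below; it mirrors only what the assertions pin down and is not the project's actual `_parse_response`.

```python
# Illustrative normalization of a Responses-style "usage" payload into the
# usage_metadata shape asserted above. Sketch only, not the real code.
def normalize_usage(usage: dict) -> dict:
    return {
        "input_tokens": usage.get("input_tokens", 0),
        "output_tokens": usage.get("output_tokens", 0),
        "total_tokens": usage.get("total_tokens", 0),
        "input_token_details": {
            "cache_read": usage.get("input_tokens_details", {}).get("cached_tokens", 0),
        },
        "output_token_details": {
            "reasoning": usage.get("output_tokens_details", {}).get("reasoning_tokens", 0),
        },
    }

meta = normalize_usage(
    {
        "input_tokens": 10,
        "output_tokens": 5,
        "total_tokens": 15,
        "input_tokens_details": {"cached_tokens": 3},
        "output_tokens_details": {"reasoning_tokens": 2},
    }
)
assert meta["input_token_details"]["cache_read"] == 3
assert meta["output_token_details"]["reasoning"] == 2
```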
@ -697,3 +697,33 @@ def test_get_available_tools_includes_invoke_acp_agent_when_agents_configured(monkeypatch):
    assert "invoke_acp_agent" in [tool.name for tool in tools]

    load_acp_config_from_dict({})


def test_get_available_tools_uses_explicit_app_config_for_acp_agents(monkeypatch):
    explicit_agents = {"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")}
    explicit_config = SimpleNamespace(
        tools=[],
        models=[],
        tool_search=SimpleNamespace(enabled=False),
        skill_evolution=SimpleNamespace(enabled=False),
        get_model_config=lambda name: None,
        acp_agents=explicit_agents,
    )
    sentinel_tool = SimpleNamespace(name="invoke_acp_agent")
    captured: dict[str, object] = {}

    def fail_get_acp_agents():
        raise AssertionError("ambient get_acp_agents() must not be used when app_config is explicit")

    def fake_build_invoke_acp_agent_tool(agents):
        captured["agents"] = agents
        return sentinel_tool

    monkeypatch.setattr("deerflow.tools.tools.is_host_bash_allowed", lambda config=None: True)
    monkeypatch.setattr("deerflow.config.acp_config.get_acp_agents", fail_get_acp_agents)
    monkeypatch.setattr("deerflow.tools.builtins.invoke_acp_agent_tool.build_invoke_acp_agent_tool", fake_build_invoke_acp_agent_tool)

    tools = get_available_tools(include_mcp=False, subagent_enabled=False, app_config=explicit_config)

    assert captured["agents"] is explicit_agents
    assert "invoke_acp_agent" in [tool.name for tool in tools]
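Editor's note: many tests in this commit enforce the same injection rule — an explicitly passed `app_config` must short-circuit any read of the ambient global accessor. Reduced to its core, the rule looks like the sketch below (names are illustrative, not the project's API).

```python
# Sketch of the "explicit wins, ambient is a fallback" rule these tests
# enforce. The key property: the ambient accessor is only called when no
# explicit value was supplied.
from typing import Any, Callable, Optional

def resolve_app_config(explicit: Optional[Any], ambient: Callable[[], Any]) -> Any:
    # Only fall back to the ambient global when nothing was passed in.
    return explicit if explicit is not None else ambient()

def _boom() -> Any:
    raise AssertionError("ambient accessor must not be used")

sentinel = object()
assert resolve_app_config(sentinel, _boom) is sentinel  # explicit path, no global read
```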
@ -72,6 +72,44 @@ def test_internal_make_lead_agent_uses_explicit_app_config(monkeypatch):
    assert result["model"] is not None


def test_make_lead_agent_uses_runtime_app_config_from_context_without_global_read(monkeypatch):
    app_config = _make_app_config([_make_model("context-model", supports_thinking=False)])

    import deerflow.tools as tools_module

    def _raise_get_app_config():
        raise AssertionError("ambient get_app_config() must not be used when runtime context already carries app_config")

    monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
    monkeypatch.setattr(tools_module, "get_available_tools", lambda **kwargs: [])
    monkeypatch.setattr(lead_agent_module, "_build_middlewares", lambda config, model_name, agent_name=None, **kwargs: [])

    captured: dict[str, object] = {}

    def _fake_create_chat_model(*, name, thinking_enabled, reasoning_effort=None, app_config=None):
        captured["name"] = name
        captured["app_config"] = app_config
        return object()

    monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
    monkeypatch.setattr(lead_agent_module, "create_agent", lambda **kwargs: kwargs)

    result = lead_agent_module.make_lead_agent(
        {
            "context": {
                "model_name": "context-model",
                "app_config": app_config,
            }
        }
    )

    assert captured == {
        "name": "context-model",
        "app_config": app_config,
    }
    assert result["model"] is not None


def test_resolve_model_name_falls_back_to_default(monkeypatch, caplog):
    app_config = _make_app_config(
        [

@ -276,6 +314,16 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
    )
    monkeypatch.setattr(lead_agent_module, "_create_summarization_middleware", lambda **kwargs: None)
    monkeypatch.setattr(lead_agent_module, "_create_todo_list_middleware", lambda is_plan_mode: None)
    monkeypatch.setattr(
        lead_agent_module,
        "TitleMiddleware",
        lambda *, app_config: captured.setdefault("title_app_config", app_config) or "title-middleware",
    )
    monkeypatch.setattr(
        lead_agent_module,
        "MemoryMiddleware",
        lambda agent_name=None, *, memory_config: captured.setdefault("memory_config", memory_config) or "memory-middleware",
    )

    middlewares = lead_agent_module._build_middlewares(
        {"configurable": {"is_plan_mode": False, "subagent_enabled": False}},

@ -286,17 +334,16 @@ def test_build_middlewares_passes_explicit_app_config_to_shared_factory(monkeypatch):
    assert captured == {
        "app_config": app_config,
        "lazy_init": True,
        "title_app_config": app_config,
        "memory_config": app_config.memory,
    }
    assert middlewares[0] == "base-middleware"


def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
    monkeypatch.setattr(
        lead_agent_module,
        "get_summarization_config",
        lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
    )
    monkeypatch.setattr(lead_agent_module, "get_memory_config", lambda: MemoryConfig(enabled=False))
    app_config = _make_app_config([_make_model("model-masswork", supports_thinking=False)])
    app_config.summarization = SummarizationConfig(enabled=True, model_name="model-masswork")
    app_config.memory = MemoryConfig(enabled=False)

    from unittest.mock import MagicMock

@ -311,13 +358,55 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch):
        captured["app_config"] = app_config
        return fake_model

    def _raise_get_app_config():
        raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")

    monkeypatch.setattr(lead_agent_module, "get_app_config", _raise_get_app_config)
    monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
    monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)

    middleware = lead_agent_module._create_summarization_middleware(app_config=_make_app_config([_make_model("model-masswork", supports_thinking=False)]))
    middleware = lead_agent_module._create_summarization_middleware(app_config=app_config)

    assert captured["name"] == "model-masswork"
    assert captured["thinking_enabled"] is False
    assert captured["app_config"] is not None
    assert captured["app_config"] is app_config
    assert middleware["model"] is fake_model
    fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])


def test_create_summarization_middleware_threads_resolved_app_config_to_model(monkeypatch):
    fallback_app_config = _make_app_config([_make_model("fallback-model", supports_thinking=False)])
    fallback_app_config.summarization = SummarizationConfig(enabled=True, model_name="fallback-model")
    fallback_app_config.memory = MemoryConfig(enabled=False)

    from unittest.mock import MagicMock

    captured: dict[str, object] = {}
    fake_model = MagicMock()
    fake_model.with_config.return_value = fake_model

    def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None, app_config=None):
        captured["app_config"] = app_config
        return fake_model

    monkeypatch.setattr(lead_agent_module, "get_app_config", lambda: fallback_app_config)
    monkeypatch.setattr(lead_agent_module, "create_chat_model", _fake_create_chat_model)
    monkeypatch.setattr(lead_agent_module, "DeerFlowSummarizationMiddleware", lambda **kwargs: kwargs)

    lead_agent_module._create_summarization_middleware()

    assert captured["app_config"] is fallback_app_config


def test_memory_middleware_uses_explicit_memory_config_without_global_read(monkeypatch):
    from deerflow.agents.middlewares import memory_middleware as memory_middleware_module
    from deerflow.agents.middlewares.memory_middleware import MemoryMiddleware

    def _raise_get_memory_config():
        raise AssertionError("ambient get_memory_config() must not be used when memory_config is explicit")

    monkeypatch.setattr(memory_middleware_module, "get_memory_config", _raise_get_memory_config)

    middleware = MemoryMiddleware(memory_config=MemoryConfig(enabled=False))

    assert middleware.after_agent({"messages": []}, runtime=MagicMock(context={"thread_id": "thread-1"})) is None
@ -4,6 +4,7 @@ from types import SimpleNamespace
import anyio

from deerflow.agents.lead_agent import prompt as prompt_module
from deerflow.config.subagents_config import CustomSubagentConfig, SubagentsAppConfig
from deerflow.skills.types import Skill


@ -40,6 +41,21 @@ def test_build_custom_mounts_section_lists_configured_mounts(monkeypatch):
    assert "read-only" in section


def test_build_custom_mounts_section_uses_explicit_app_config_without_global_read(monkeypatch):
    mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)]
    config = SimpleNamespace(sandbox=SimpleNamespace(mounts=mounts))

    def fail_get_app_config():
        raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")

    monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)

    section = prompt_module._build_custom_mounts_section(app_config=config)

    assert "`/home/user/shared`" in section
    assert "read-write" in section


def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
    mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)]
    config = SimpleNamespace(

@ -49,8 +65,8 @@ def test_apply_prompt_template_includes_custom_mounts(monkeypatch):
    monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
    monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
    monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda **kwargs: "")
    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None, **kwargs: "")
    monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

    prompt = prompt_module.apply_prompt_template()

@ -67,8 +83,8 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
    monkeypatch.setattr("deerflow.config.get_app_config", lambda: config)
    monkeypatch.setattr(prompt_module, "_get_enabled_skills", lambda: [])
    monkeypatch.setattr(prompt_module, "get_deferred_tools_prompt_section", lambda **kwargs: "")
    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda: "")
    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None: "")
    monkeypatch.setattr(prompt_module, "_build_acp_section", lambda **kwargs: "")
    monkeypatch.setattr(prompt_module, "_get_memory_context", lambda agent_name=None, **kwargs: "")
    monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

    prompt = prompt_module.apply_prompt_template()

@ -77,6 +93,123 @@ def test_apply_prompt_template_includes_relative_path_guidance(monkeypatch):
    assert "`hello.txt`, `../uploads/data.csv`, and `../outputs/report.md`" in prompt


def test_apply_prompt_template_threads_explicit_app_config_without_global_config(monkeypatch):
    mounts = [SimpleNamespace(container_path="/home/user/shared", read_only=False)]
    explicit_config = SimpleNamespace(
        sandbox=SimpleNamespace(mounts=mounts),
        skills=SimpleNamespace(container_path="/mnt/explicit-skills"),
        skill_evolution=SimpleNamespace(enabled=False),
        tool_search=SimpleNamespace(enabled=False),
        memory=SimpleNamespace(enabled=False, injection_enabled=True, max_injection_tokens=2000),
        acp_agents={},
    )

    def fail_get_app_config():
        raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")

    def fail_get_memory_config():
        raise AssertionError("ambient get_memory_config() must not be used when app_config is explicit")

    monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)
    monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
    monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None: SimpleNamespace(load_skills=lambda enabled_only=True: []))
    monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

    prompt = prompt_module.apply_prompt_template(app_config=explicit_config)

    assert "`/home/user/shared`" in prompt
    assert "Custom Mounted Directories" in prompt


def test_apply_prompt_template_threads_explicit_app_config_to_subagents_without_global_config(monkeypatch):
    explicit_config = SimpleNamespace(
        sandbox=SimpleNamespace(
            use="deerflow.sandbox.local:LocalSandboxProvider",
            allow_host_bash=False,
            mounts=[],
        ),
        subagents=SubagentsAppConfig(
            custom_agents={
                "researcher": CustomSubagentConfig(
                    description="Research agent\nwith details",
                    system_prompt="You research.",
                )
            }
        ),
        skills=SimpleNamespace(container_path="/mnt/skills"),
        skill_evolution=SimpleNamespace(enabled=False),
        tool_search=SimpleNamespace(enabled=False),
        memory=SimpleNamespace(enabled=False, injection_enabled=True, max_injection_tokens=2000),
        acp_agents={},
    )

    def fail_get_app_config():
        raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")

    def fail_get_subagents_app_config():
        raise AssertionError("ambient get_subagents_app_config() must not be used when app_config is explicit")

    monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)
    monkeypatch.setattr("deerflow.config.subagents_config.get_subagents_app_config", fail_get_subagents_app_config)
    monkeypatch.setattr(prompt_module, "get_or_new_skill_storage", lambda app_config=None: SimpleNamespace(load_skills=lambda enabled_only=True: []))
    monkeypatch.setattr(prompt_module, "get_agent_soul", lambda agent_name=None: "")

    prompt = prompt_module.apply_prompt_template(subagent_enabled=True, app_config=explicit_config)

    assert "**researcher**: Research agent" in prompt
    assert "**bash**" not in prompt


def test_build_acp_section_uses_explicit_app_config_without_global_config(monkeypatch):
    explicit_config = SimpleNamespace(acp_agents={"codex": object()})

    def fail_get_acp_agents():
        raise AssertionError("ambient get_acp_agents() must not be used when app_config is explicit")

    monkeypatch.setattr("deerflow.config.acp_config.get_acp_agents", fail_get_acp_agents)

    section = prompt_module._build_acp_section(app_config=explicit_config)

    assert "ACP Agent Tasks" in section
    assert "/mnt/acp-workspace/" in section


def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
    explicit_config = SimpleNamespace(
        memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234),
    )
    captured: dict[str, object] = {}

    def fail_get_memory_config():
        raise AssertionError("ambient get_memory_config() must not be used when app_config is explicit")

    def fake_get_memory_data(agent_name=None, *, user_id=None):
        captured["agent_name"] = agent_name
        captured["user_id"] = user_id
        return {"facts": []}

    def fake_format_memory_for_injection(memory_data, *, max_tokens):
        captured["memory_data"] = memory_data
        captured["max_tokens"] = max_tokens
        return "remember this"

    monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
    monkeypatch.setattr("deerflow.runtime.user_context.get_effective_user_id", lambda: "user-1")
    monkeypatch.setattr("deerflow.agents.memory.get_memory_data", fake_get_memory_data)
    monkeypatch.setattr("deerflow.agents.memory.format_memory_for_injection", fake_format_memory_for_injection)

    context = prompt_module._get_memory_context("agent-a", app_config=explicit_config)

    assert "<memory>" in context
    assert "remember this" in context
    assert captured == {
        "agent_name": "agent-a",
        "user_id": "user-1",
        "memory_data": {"facts": []},
        "max_tokens": 1234,
    }


def test_refresh_skills_system_prompt_cache_async_reloads_immediately(monkeypatch, tmp_path):
    def make_skill(name: str) -> Skill:
        skill_dir = tmp_path / name

@ -106,7 +106,11 @@ def test_get_skills_prompt_section_uses_explicit_config_for_enabled_skills(monkeypatch):
        skill_evolution=SimpleNamespace(enabled=False),
    )

    def fail_get_app_config():
        raise AssertionError("ambient get_app_config() must not be used when app_config is explicit")

    monkeypatch.setattr("deerflow.agents.lead_agent.prompt._get_enabled_skills", lambda: [_make_skill("global-skill")])
    monkeypatch.setattr("deerflow.config.get_app_config", fail_get_app_config)
    monkeypatch.setattr(
        "deerflow.agents.lead_agent.prompt.get_or_new_skill_storage",
        lambda app_config=None, **kwargs: __import__("types").SimpleNamespace(load_skills=lambda *, enabled_only: [_make_skill("explicit-skill")] if app_config is explicit_config else []),
@ -1,8 +1,14 @@
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, call

import pytest
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver

from deerflow.runtime.runs.worker import _agent_factory_supports_app_config, _build_runtime_context, _rollback_to_pre_run_checkpoint
from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus
from deerflow.runtime.runs.worker import RunContext, _agent_factory_supports_app_config, _build_runtime_context, _install_runtime_context, _rollback_to_pre_run_checkpoint, run_agent


class FakeCheckpointer:

@ -12,6 +18,81 @@ class FakeCheckpointer:
        self.aput_writes = AsyncMock()


def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
    checkpoint = empty_checkpoint()
    checkpoint["id"] = checkpoint_id
    checkpoint["channel_values"] = {"messages": messages}
    checkpoint["channel_versions"] = {"messages": version}
    return checkpoint


def test_build_runtime_context_includes_app_config_when_present():
    app_config = object()

    context = _build_runtime_context("thread-1", "run-1", None, app_config)

    assert context["thread_id"] == "thread-1"
    assert context["run_id"] == "run-1"
    assert context["app_config"] is app_config


def test_install_runtime_context_preserves_existing_thread_id_and_threads_app_config():
    app_config = object()
    config = {"context": {"thread_id": "caller-thread"}}

    _install_runtime_context(
        config,
        {
            "thread_id": "record-thread",
            "run_id": "run-1",
            "app_config": app_config,
        },
    )

    assert config["context"]["thread_id"] == "caller-thread"
    assert config["context"]["run_id"] == "run-1"
    assert config["context"]["app_config"] is app_config


@pytest.mark.anyio
async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
    run_manager = RunManager()
    record = await run_manager.create("thread-1")
    bridge = SimpleNamespace(
        publish=AsyncMock(),
        publish_end=AsyncMock(),
        cleanup=AsyncMock(),
    )
    app_config = object()
    captured: dict[str, object] = {}

    class DummyAgent:
        async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
            captured["astream_context"] = config["context"]
            yield {"messages": []}

    def factory(*, config):
        captured["factory_context"] = config["context"]
        return DummyAgent()

    await run_agent(
        bridge,
        run_manager,
        record,
        ctx=RunContext(checkpointer=None, app_config=app_config),
        agent_factory=factory,
        graph_input={},
        config={},
    )
    await asyncio.sleep(0)

    assert captured["factory_context"]["app_config"] is app_config
    assert captured["astream_context"]["app_config"] is app_config
    assert run_manager.get(record.run_id).status == RunStatus.success
    bridge.publish_end.assert_awaited_once_with(record.run_id)
    bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)


@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})

@ -39,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
    )

    checkpointer.adelete_thread.assert_not_awaited()
    checkpointer.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {
            "id": "ckpt-1",
            "channel_versions": {"messages": 3},
            "channel_values": {"messages": ["before"]},
        },
        {"source": "input"},
        {"messages": 3},
    )
    checkpointer.aput.assert_awaited_once()
    restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
    assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    assert restored_checkpoint["id"] != "ckpt-1"
    assert "channel_versions" in restored_checkpoint
    assert "channel_values" in restored_checkpoint
    assert restored_checkpoint["channel_versions"] == {"messages": 3}
    assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
    assert restored_metadata == {"source": "input"}
    assert new_versions == {"messages": 3}
    assert checkpointer.aput_writes.await_args_list == [
        call(
            {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},

@ -63,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
    ]


@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
    checkpointer = InMemorySaver()
    thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    before_checkpoint = _make_checkpoint("0001", ["before"], 1)
    before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
    after_checkpoint = _make_checkpoint("0002", ["after"], 2)
    after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
    checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")

    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="0001",
        pre_run_snapshot={
            "checkpoint_ns": "",
            "checkpoint": before_checkpoint,
            "metadata": {"step": 1},
            "pending_writes": [("task-before", "messages", "pending-before")],
        },
        snapshot_capture_failed=False,
    )

    latest = checkpointer.get_tuple(thread_config)

    assert latest is not None
    assert latest.config["configurable"]["checkpoint_id"] != "0001"
    assert latest.config["configurable"]["checkpoint_id"] != "0002"
    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
    assert ("task-after", "messages", "pending-after") not in latest.pending_writes


@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
    checkpointer = FakeCheckpointer(put_result=None)

@ -123,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
        snapshot_capture_failed=False,
    )

    checkpointer.aput.assert_awaited_once_with(
        {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
        {"id": "ckpt-1", "channel_versions": {}},
        {},
        {},
    )
    checkpointer.aput.assert_awaited_once()
    restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
    assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    assert restored_checkpoint["id"] != "ckpt-1"
    assert restored_checkpoint["channel_versions"] == {}
    assert restored_metadata == {}
    assert new_versions == {}


@pytest.mark.anyio
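Editor's note: the rollback tests above converge on one contract — the pre-run snapshot is written back through `aput()` as a new checkpoint with a fresh id, so it becomes the latest entry rather than the thread being deleted. A toy sketch of just the id-refresh step (illustrative only, not the worker's code):

```python
# Restoring a snapshot must mint a fresh checkpoint id; otherwise the
# checkpointer would not treat the restored state as the newest entry.
import uuid

def restore_as_latest(snapshot_checkpoint: dict) -> dict:
    restored = dict(snapshot_checkpoint)
    restored["id"] = str(uuid.uuid4())  # fresh id, never the original one
    return restored

before = {"id": "ckpt-1", "channel_values": {"messages": ["before"]}, "channel_versions": {"messages": 3}}
latest = restore_as_latest(before)
assert latest["id"] != "ckpt-1"
assert latest["channel_values"] == {"messages": ["before"]}
```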
145 backend/tests/test_runtime_paths.py Normal file
@ -0,0 +1,145 @@
"""Runtime path policy tests for standalone harness usage."""

from pathlib import Path

import pytest
import yaml

from deerflow.config import app_config as app_config_module
from deerflow.config import extensions_config as extensions_config_module
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.paths import Paths
from deerflow.config.runtime_paths import project_root
from deerflow.config.skills_config import SkillsConfig
from deerflow.skills.storage import get_or_new_skill_storage


def _clear_path_env(monkeypatch):
    for name in (
        "DEER_FLOW_CONFIG_PATH",
        "DEER_FLOW_EXTENSIONS_CONFIG_PATH",
        "DEER_FLOW_HOME",
        "DEER_FLOW_PROJECT_ROOT",
        "DEER_FLOW_SKILLS_PATH",
    ):
        monkeypatch.delenv(name, raising=False)


def test_default_runtime_paths_resolve_from_current_project(tmp_path: Path, monkeypatch):
    _clear_path_env(monkeypatch)
    monkeypatch.chdir(tmp_path)

    (tmp_path / "config.yaml").write_text(
        yaml.safe_dump({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}),
        encoding="utf-8",
    )
    (tmp_path / "extensions_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")

    assert AppConfig.resolve_config_path() == tmp_path / "config.yaml"
    assert ExtensionsConfig.resolve_config_path() == tmp_path / "extensions_config.json"
    assert Paths().base_dir == tmp_path / ".deer-flow"
    assert SkillsConfig().get_skills_path() == tmp_path / "skills"
    assert get_or_new_skill_storage(skills_path=SkillsConfig().get_skills_path()).get_skills_root_path() == tmp_path / "skills"


def test_deer_flow_project_root_overrides_current_directory(tmp_path: Path, monkeypatch):
    _clear_path_env(monkeypatch)
    project_root = tmp_path / "project"
    other_cwd = tmp_path / "other"
    project_root.mkdir()
    other_cwd.mkdir()
    monkeypatch.chdir(other_cwd)
    monkeypatch.setenv("DEER_FLOW_PROJECT_ROOT", str(project_root))

    (project_root / "config.yaml").write_text(
        yaml.safe_dump({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}),
        encoding="utf-8",
    )
    (project_root / "mcp_config.json").write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")

    assert AppConfig.resolve_config_path() == project_root / "config.yaml"
    assert ExtensionsConfig.resolve_config_path() == project_root / "mcp_config.json"
    assert Paths().base_dir == project_root / ".deer-flow"
    assert SkillsConfig(path="custom-skills").get_skills_path() == project_root / "custom-skills"


def test_deer_flow_skills_path_overrides_project_default(tmp_path: Path, monkeypatch):
    _clear_path_env(monkeypatch)
    monkeypatch.chdir(tmp_path)
    monkeypatch.setenv("DEER_FLOW_SKILLS_PATH", "team-skills")

    assert SkillsConfig().get_skills_path() == tmp_path / "team-skills"
    assert get_or_new_skill_storage(skills_path=SkillsConfig().get_skills_path()).get_skills_root_path() == tmp_path / "team-skills"


def test_deer_flow_project_root_must_exist(tmp_path: Path, monkeypatch):
    _clear_path_env(monkeypatch)
    missing_root = tmp_path / "missing"
    monkeypatch.setenv("DEER_FLOW_PROJECT_ROOT", str(missing_root))

    with pytest.raises(ValueError, match="does not exist"):
        project_root()


def test_deer_flow_project_root_must_be_directory(tmp_path: Path, monkeypatch):
    _clear_path_env(monkeypatch)
    project_root_file = tmp_path / "project-root"
    project_root_file.write_text("", encoding="utf-8")
    monkeypatch.setenv("DEER_FLOW_PROJECT_ROOT", str(project_root_file))

    with pytest.raises(ValueError, match="not a directory"):
        project_root()


def test_app_config_falls_back_to_legacy_when_project_root_lacks_config(tmp_path: Path, monkeypatch):
    """When DEER_FLOW_PROJECT_ROOT is unset and cwd has no config.yaml, the
    legacy backend/repo-root candidates must be used for monorepo compatibility."""
    _clear_path_env(monkeypatch)
    cwd = tmp_path / "cwd"
    cwd.mkdir()
    monkeypatch.chdir(cwd)

    legacy_backend = tmp_path / "legacy-backend"
    legacy_repo = tmp_path / "legacy-repo"
    legacy_backend.mkdir()
    legacy_repo.mkdir()
    legacy_backend_config = legacy_backend / "config.yaml"
    legacy_backend_config.write_text(
        yaml.safe_dump({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}),
        encoding="utf-8",
    )
    repo_root_config = legacy_repo / "config.yaml"
    repo_root_config.write_text("", encoding="utf-8")

    monkeypatch.setattr(
        app_config_module,
        "_legacy_config_candidates",
        lambda: (legacy_backend_config, repo_root_config),
    )

    assert AppConfig.resolve_config_path() == legacy_backend_config


def test_extensions_config_falls_back_to_legacy_when_project_root_lacks_file(tmp_path: Path, monkeypatch):
    """ExtensionsConfig should hit the legacy backend/repo-root locations when
    the caller project root has no extensions_config.json/mcp_config.json."""
    _clear_path_env(monkeypatch)
    cwd = tmp_path / "cwd"
    cwd.mkdir()
    monkeypatch.chdir(cwd)

    fake_backend = tmp_path / "fake-backend"
    fake_repo = tmp_path / "fake-repo"
    fake_backend.mkdir()
    fake_repo.mkdir()
    legacy_extensions = fake_backend / "extensions_config.json"
    legacy_extensions.write_text('{"mcpServers": {}, "skills": {}}', encoding="utf-8")

    fake_paths_module_file = fake_backend / "packages" / "harness" / "deerflow" / "config" / "extensions_config.py"
    fake_paths_module_file.parent.mkdir(parents=True)
    fake_paths_module_file.write_text("", encoding="utf-8")

    monkeypatch.setattr(extensions_config_module, "__file__", str(fake_paths_module_file))

    assert ExtensionsConfig.resolve_config_path() == legacy_extensions
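Editor's note: taken together, these tests encode a resolution order — `DEER_FLOW_PROJECT_ROOT` (validated as an existing directory) beats the current working directory, and `DEER_FLOW_SKILLS_PATH` (absolute, or relative to the project root) beats `<project root>/skills`. A standalone sketch of that order follows, under the assumption that it mirrors the tested behavior; it is not the project's `runtime_paths` module.

```python
# Sketch of the path-resolution order the tests above pin down.
import os
from pathlib import Path

def resolve_project_root() -> Path:
    override = os.environ.get("DEER_FLOW_PROJECT_ROOT")
    if override:
        root = Path(override)
        if not root.exists():
            raise ValueError(f"DEER_FLOW_PROJECT_ROOT does not exist: {root}")
        if not root.is_dir():
            raise ValueError(f"DEER_FLOW_PROJECT_ROOT is not a directory: {root}")
        return root
    return Path.cwd()

def resolve_skills_path() -> Path:
    override = os.environ.get("DEER_FLOW_SKILLS_PATH")
    if override:
        p = Path(override)
        return p if p.is_absolute() else resolve_project_root() / p
    return resolve_project_root() / "skills"
```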
308 backend/tests/test_serper_tools.py Normal file
@ -0,0 +1,308 @@
"""Unit tests for the Serper community web search tool."""

import json
from unittest.mock import MagicMock, patch

import httpx
import pytest


@pytest.fixture(autouse=True)
def reset_api_key_warned():
    """Reset the module-level warning flag before each test."""
    import deerflow.community.serper.tools as serper_mod

    serper_mod._api_key_warned = False
    yield
    serper_mod._api_key_warned = False


@pytest.fixture
def mock_config_with_key():
    with patch("deerflow.community.serper.tools.get_app_config") as mock:
        tool_config = MagicMock()
        tool_config.model_extra = {"api_key": "test-serper-key", "max_results": 5}
        mock.return_value.get_tool_config.return_value = tool_config
        yield mock


@pytest.fixture
def mock_config_no_key():
    with patch("deerflow.community.serper.tools.get_app_config") as mock:
        tool_config = MagicMock()
        tool_config.model_extra = {}
        mock.return_value.get_tool_config.return_value = tool_config
        yield mock


def _make_serper_response(organic: list) -> MagicMock:
    mock_resp = MagicMock()
    mock_resp.json.return_value = {"organic": organic}
    mock_resp.raise_for_status = MagicMock()
    return mock_resp


class TestGetApiKey:
    def test_returns_config_key_when_present(self):
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            tool_config = MagicMock()
            tool_config.model_extra = {"api_key": "from-config"}
            mock.return_value.get_tool_config.return_value = tool_config

            from deerflow.community.serper.tools import _get_api_key

            assert _get_api_key() == "from-config"

    def test_falls_back_to_env_when_config_key_empty(self):
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            tool_config = MagicMock()
            tool_config.model_extra = {"api_key": ""}
            mock.return_value.get_tool_config.return_value = tool_config
            with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
                from deerflow.community.serper.tools import _get_api_key

                assert _get_api_key() == "env-key"

    def test_falls_back_to_env_when_config_key_whitespace(self):
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            tool_config = MagicMock()
            tool_config.model_extra = {"api_key": " "}
            mock.return_value.get_tool_config.return_value = tool_config
            with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
                from deerflow.community.serper.tools import _get_api_key

                assert _get_api_key() == "env-key"

    def test_falls_back_to_env_when_config_key_null(self):
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            tool_config = MagicMock()
            tool_config.model_extra = {"api_key": None}
            mock.return_value.get_tool_config.return_value = tool_config
            with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
                from deerflow.community.serper.tools import _get_api_key

                assert _get_api_key() == "env-key"

    def test_falls_back_to_env_when_no_config(self):
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            mock.return_value.get_tool_config.return_value = None
            with patch.dict("os.environ", {"SERPER_API_KEY": "env-only"}):
                from deerflow.community.serper.tools import _get_api_key

                assert _get_api_key() == "env-only"

    def test_returns_none_when_no_key_anywhere(self):
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            mock.return_value.get_tool_config.return_value = None
            with patch.dict("os.environ", {}, clear=True):
                import os

                os.environ.pop("SERPER_API_KEY", None)
                from deerflow.community.serper.tools import _get_api_key

                assert _get_api_key() is None


class TestWebSearchTool:
    def test_basic_search_returns_normalized_results(self, mock_config_with_key):
        organic = [
            {"title": "Result 1", "link": "https://example.com/1", "snippet": "Snippet 1"},
            {"title": "Result 2", "link": "https://example.com/2", "snippet": "Snippet 2"},
        ]
        mock_resp = _make_serper_response(organic)

        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp

            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "python tutorial"})
            parsed = json.loads(result)

            assert parsed["query"] == "python tutorial"
            assert parsed["total_results"] == 2
            assert parsed["results"][0]["title"] == "Result 1"
            assert parsed["results"][0]["url"] == "https://example.com/1"
            assert parsed["results"][0]["content"] == "Snippet 1"

    def test_respects_max_results_from_config(self, mock_config_with_key):
        mock_config_with_key.return_value.get_tool_config.return_value.model_extra = {
            "api_key": "test-key",
            "max_results": 3,
        }
        organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
        mock_resp = _make_serper_response(organic)

        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp

            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "test"})
            parsed = json.loads(result)

            assert parsed["total_results"] == 3
            assert len(parsed["results"]) == 3

    def test_max_results_parameter_accepted(self, mock_config_no_key):
        """Tool accepts max_results as a call parameter when config does not override it."""
        organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
        mock_resp = _make_serper_response(organic)

        with patch.dict("os.environ", {"SERPER_API_KEY": "env-key"}):
            with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
                mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp

                from deerflow.community.serper.tools import web_search_tool

                result = web_search_tool.invoke({"query": "test", "max_results": 2})
                parsed = json.loads(result)

                assert parsed["total_results"] == 2

    def test_config_max_results_overrides_parameter(self):
        """Config max_results overrides the parameter passed at call time, matching ddg_search behaviour."""
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            tool_config = MagicMock()
            tool_config.model_extra = {"api_key": "test-key", "max_results": 3}
            mock.return_value.get_tool_config.return_value = tool_config

            organic = [{"title": f"R{i}", "link": f"https://x.com/{i}", "snippet": f"S{i}"} for i in range(10)]
            mock_resp = _make_serper_response(organic)

            with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
                mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp

                from deerflow.community.serper.tools import web_search_tool

                result = web_search_tool.invoke({"query": "test", "max_results": 8})
                parsed = json.loads(result)

                assert parsed["total_results"] == 3

    def test_empty_organic_returns_error_json(self, mock_config_with_key):
        """Empty organic list returns structured error, matching ddg_search convention."""
        mock_resp = _make_serper_response([])

        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp

            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "no results"})
            parsed = json.loads(result)

            assert "error" in parsed
            assert parsed["error"] == "No results found"
            assert parsed["query"] == "no results"

    def test_missing_api_key_returns_error_json(self, mock_config_no_key):
        with patch.dict("os.environ", {}, clear=True):
            import os

            os.environ.pop("SERPER_API_KEY", None)

            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "test"})
            parsed = json.loads(result)

            assert "error" in parsed
            assert "SERPER_API_KEY" in parsed["error"]

    def test_missing_api_key_logs_warning_once(self, mock_config_no_key, caplog):
        import logging

        with patch.dict("os.environ", {}, clear=True):
            import os

            os.environ.pop("SERPER_API_KEY", None)

            from deerflow.community.serper.tools import web_search_tool

            with caplog.at_level(logging.WARNING, logger="deerflow.community.serper.tools"):
                web_search_tool.invoke({"query": "q1"})
                web_search_tool.invoke({"query": "q2"})

            warnings = [r for r in caplog.records if r.levelno == logging.WARNING]
            assert len(warnings) == 1

    def test_http_error_returns_structured_error(self, mock_config_with_key):
        mock_error_response = MagicMock()
        mock_error_response.status_code = 403
        mock_error_response.text = "Forbidden"

        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_client_cls.return_value.__enter__.return_value.post.side_effect = httpx.HTTPStatusError("403", request=MagicMock(), response=mock_error_response)

            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "test"})
            parsed = json.loads(result)

            assert "error" in parsed
            assert "403" in parsed["error"]

    def test_network_exception_returns_error_json(self, mock_config_with_key):
        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_client_cls.return_value.__enter__.return_value.post.side_effect = Exception("timeout")

            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "test"})
            parsed = json.loads(result)

            assert "error" in parsed

    def test_sends_correct_headers_and_payload(self, mock_config_with_key):
        organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
        mock_resp = _make_serper_response(organic)

        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_post = mock_client_cls.return_value.__enter__.return_value.post
            mock_post.return_value = mock_resp

            from deerflow.community.serper.tools import web_search_tool

            web_search_tool.invoke({"query": "hello world"})

            call_kwargs = mock_post.call_args
            headers = call_kwargs.kwargs["headers"]
            payload = call_kwargs.kwargs["json"]

            assert headers["X-API-KEY"] == "test-serper-key"
            assert payload["q"] == "hello world"
            assert payload["num"] == 5

    def test_uses_env_key_when_config_absent(self):
        with patch("deerflow.community.serper.tools.get_app_config") as mock:
            mock.return_value.get_tool_config.return_value = None
            with patch.dict("os.environ", {"SERPER_API_KEY": "env-only-key"}):
                organic = [{"title": "T", "link": "https://x.com", "snippet": "S"}]
                mock_resp = _make_serper_response(organic)

                with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
                    mock_post = mock_client_cls.return_value.__enter__.return_value.post
                    mock_post.return_value = mock_resp

                    from deerflow.community.serper.tools import web_search_tool

                    web_search_tool.invoke({"query": "env key test"})
                    headers = mock_post.call_args.kwargs["headers"]

                    assert headers["X-API-KEY"] == "env-only-key"

    def test_partial_fields_in_organic_result(self, mock_config_with_key):
        """Missing title/link/snippet should default to empty string."""
        organic = [{}]
        mock_resp = _make_serper_response(organic)

        with patch("deerflow.community.serper.tools.httpx.Client") as mock_client_cls:
            mock_client_cls.return_value.__enter__.return_value.post.return_value = mock_resp

            from deerflow.community.serper.tools import web_search_tool

            result = web_search_tool.invoke({"query": "test"})
            parsed = json.loads(result)

            assert parsed["results"][0] == {"title": "", "url": "", "content": ""}
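Editor's note: `TestGetApiKey` pins down a three-step key lookup — a non-blank `api_key` in the tool config wins, then the `SERPER_API_KEY` environment variable, then `None`. A compact sketch of that order (illustrative, not the module's `_get_api_key`; `tool_extra` stands in for the tool config's `model_extra`):

```python
# Sketch of the key-resolution order the tests above encode.
import os
from typing import Optional

def resolve_serper_key(tool_extra: Optional[dict]) -> Optional[str]:
    configured = (tool_extra or {}).get("api_key")
    if isinstance(configured, str) and configured.strip():
        return configured  # non-blank config value wins
    return os.environ.get("SERPER_API_KEY") or None

assert resolve_serper_key({"api_key": "from-config"}) == "from-config"
os.environ["SERPER_API_KEY"] = "env-key"
assert resolve_serper_key({"api_key": " "}) == "env-key"  # blank falls through
assert resolve_serper_key(None) == "env-key"              # no config at all
```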
@ -14,12 +14,25 @@ def _write_skill(skill_dir: Path, name: str, description: str) -> None:
    (skill_dir / "SKILL.md").write_text(content, encoding="utf-8")


def test_get_skills_root_path_points_to_project_root_skills():
    """get_skills_root_path() should point to deer-flow/skills (sibling of backend/), not backend/packages/skills."""
def test_get_skills_root_path_points_to_current_project_skills(tmp_path: Path, monkeypatch):
    """get_skills_root_path() should point to the caller project skills directory."""
    monkeypatch.delenv("DEER_FLOW_SKILLS_PATH", raising=False)
    monkeypatch.delenv("DEER_FLOW_PROJECT_ROOT", raising=False)
    monkeypatch.chdir(tmp_path)

    app_config = SimpleNamespace(skills=SkillsConfig())
    path = get_or_new_skill_storage(app_config=app_config).get_skills_root_path()
    assert path.name == "skills", f"Expected 'skills', got '{path.name}'"
    assert (path.parent / "backend").is_dir(), f"Expected skills path's parent to be project root containing 'backend/', but got {path}"
    assert path == tmp_path / "skills"


def test_get_skills_root_path_honors_env_override(tmp_path: Path, monkeypatch):
    """DEER_FLOW_SKILLS_PATH should override the caller project skills directory."""
    skills_root = tmp_path / "team-skills"
    monkeypatch.setenv("DEER_FLOW_SKILLS_PATH", str(skills_root))

    app_config = SimpleNamespace(skills=SkillsConfig())
    path = get_or_new_skill_storage(app_config=app_config).get_skills_root_path()
    assert path == skills_root


def test_load_skills_discovers_nested_skills_and_sets_container_paths(tmp_path: Path):
@ -204,7 +204,7 @@ class TestAgentConstruction:

        SubagentExecutor = classes["SubagentExecutor"]

        app_config = object()
        app_config = SimpleNamespace(models=[SimpleNamespace(name="default-model")])
        model = object()
        middlewares = [object()]
        agent = object()

@ -258,6 +258,7 @@ class TestAgentConstruction:
        }
        assert captured["middlewares"] == {
            "app_config": app_config,
            "model_name": "parent-model",
            "lazy_init": True,
        }
        assert captured["agent"]["model"] is model

@ -265,6 +266,43 @@ class TestAgentConstruction:
        assert captured["agent"]["tools"] == []
        assert captured["agent"]["system_prompt"] == base_config.system_prompt

    @pytest.mark.anyio
    async def test_load_skill_messages_uses_explicit_app_config_for_skill_storage(
        self,
        classes,
        base_config,
        monkeypatch: pytest.MonkeyPatch,
        tmp_path,
    ):
        """Explicit app_config must be threaded into subagent skill storage lookup."""
        SubagentExecutor = classes["SubagentExecutor"]

        app_config = SimpleNamespace(models=[SimpleNamespace(name="default-model")])
        skill_dir = tmp_path / "demo-skill"
        skill_dir.mkdir()
        skill_file = skill_dir / "SKILL.md"
        skill_file.write_text("Use demo skill", encoding="utf-8")
        captured: dict[str, object] = {}

        def fake_get_or_new_skill_storage(*, app_config=None):
            captured["app_config"] = app_config
            return SimpleNamespace(load_skills=lambda *, enabled_only: [SimpleNamespace(name="demo-skill", skill_file=skill_file)])

        monkeypatch.setattr("deerflow.skills.storage.get_or_new_skill_storage", fake_get_or_new_skill_storage)

        executor = SubagentExecutor(
            config=base_config,
            tools=[],
            app_config=app_config,
            thread_id="test-thread",
        )

        messages = await executor._load_skill_messages()

        assert captured["app_config"] is app_config
        assert len(messages) == 1
        assert "Use demo skill" in messages[0].content


# -----------------------------------------------------------------------------
# Async Execution Path Tests
@ -9,6 +9,8 @@ Covers:
- Skills filter passthrough in task_tool config assembly
"""

from types import SimpleNamespace

import pytest

from deerflow.config.subagents_config import (

@ -343,12 +345,54 @@ class TestRegistryCustomAgentLookup:
        assert config.timeout_seconds == 600
        assert config.model == "inherit"

    def test_custom_agent_found_from_explicit_app_config_without_global_config(self, monkeypatch):
        from deerflow.subagents.registry import get_subagent_config

        def fail_get_subagents_app_config():
            raise AssertionError("ambient get_subagents_app_config() must not be used when app_config is explicit")

        monkeypatch.setattr("deerflow.config.subagents_config.get_subagents_app_config", fail_get_subagents_app_config)

        app_config = SimpleNamespace(
            subagents=SubagentsAppConfig(
                custom_agents={
                    "analysis": CustomSubagentConfig(
                        description="Data analysis specialist",
                        system_prompt="You are a data analysis subagent.",
                        skills=["data-analysis"],
                    )
                }
            )
        )

        config = get_subagent_config("analysis", app_config=app_config)

        assert config is not None
        assert config.name == "analysis"
        assert config.skills == ["data-analysis"]

    def test_custom_agent_not_found(self):
        from deerflow.subagents.registry import get_subagent_config

        _reset_subagents_config()
        assert get_subagent_config("nonexistent") is None

    def test_get_available_subagent_names_falls_back_when_subagents_app_config_lacks_sandbox(self, monkeypatch):
        from deerflow.subagents import registry as registry_module
        from deerflow.subagents.registry import get_available_subagent_names

        captured: dict[str, tuple] = {}

        def fake_is_host_bash_allowed(*args, **kwargs):
            captured["args"] = args
            return True

        monkeypatch.setattr(registry_module, "is_host_bash_allowed", fake_is_host_bash_allowed)

        get_available_subagent_names(app_config=SubagentsAppConfig())

        assert captured["args"] == ()

    def test_builtin_takes_priority_over_custom(self):
        """If a custom agent has the same name as a builtin, builtin wins."""
        from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
@ -24,8 +24,11 @@ class FakeSubagentStatus(Enum):
    TIMED_OUT = "timed_out"


def _make_runtime() -> SimpleNamespace:
def _make_runtime(*, app_config=None) -> SimpleNamespace:
    # Minimal ToolRuntime-like object; task_tool only reads these three attributes.
    context = {"thread_id": "thread-1"}
    if app_config is not None:
        context["app_config"] = app_config
    return SimpleNamespace(
        state={
            "sandbox": {"sandbox_id": "local"},

@ -35,14 +38,14 @@
                "outputs_path": "/tmp/outputs",
            },
        },
        context={"thread_id": "thread-1"},
        context=context,
        config={"metadata": {"model_name": "ark-model", "trace_id": "trace-1"}},
    )


def _make_subagent_config() -> SubagentConfig:
def _make_subagent_config(name: str = "general-purpose") -> SubagentConfig:
    return SubagentConfig(
        name="general-purpose",
        name=name,
        description="General helper",
        system_prompt="Base system prompt",
        max_turns=50,

@ -112,6 +115,68 @@ def test_task_tool_rejects_bash_subagent_when_host_bash_disabled(monkeypatch):
    assert result.startswith("Error: Bash subagent is disabled")


def test_task_tool_threads_runtime_app_config_to_subagent_dependencies(monkeypatch):
    app_config = object()
    config = _make_subagent_config(name="bash")
    runtime = _make_runtime(app_config=app_config)
    events = []
    captured = {}

    class DummyExecutor:
        def __init__(self, **kwargs):
            captured["executor_kwargs"] = kwargs

        def execute_async(self, prompt, task_id=None):
            captured["prompt"] = prompt
            return task_id or "generated-task-id"

    def fake_get_available_subagent_names(*, app_config):
        captured["names_app_config"] = app_config
        return ["bash"]

    def fake_get_subagent_config(name, *, app_config):
        captured["config_lookup"] = (name, app_config)
        return config

    def fake_is_host_bash_allowed(config):
        captured["bash_gate_app_config"] = config
        return True

    def fake_get_available_tools(**kwargs):
        captured["tools_kwargs"] = kwargs
        return ["tool-a"]

    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
    monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
    monkeypatch.setattr(task_tool_module, "get_available_subagent_names", fake_get_available_subagent_names)
    monkeypatch.setattr(task_tool_module, "get_subagent_config", fake_get_subagent_config)
    monkeypatch.setattr(task_tool_module, "is_host_bash_allowed", fake_is_host_bash_allowed)
    monkeypatch.setattr(
        task_tool_module,
        "get_background_task_result",
        lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
    )
    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
    monkeypatch.setattr("deerflow.tools.get_available_tools", fake_get_available_tools)

    output = _run_task_tool(
        runtime=runtime,
        description="运行命令",
        prompt="inspect files",
        subagent_type="bash",
        tool_call_id="tc-explicit-config",
    )

    assert output == "Task Succeeded. Result: done"
    assert captured["names_app_config"] is app_config
    assert captured["config_lookup"] == ("bash", app_config)
    assert captured["bash_gate_app_config"] is app_config
    assert captured["tools_kwargs"]["app_config"] is app_config
    assert captured["executor_kwargs"]["app_config"] is app_config
    assert captured["executor_kwargs"]["tools"] == ["tool-a"]


def test_task_tool_emits_running_and_completed_events(monkeypatch):
    config = _make_subagent_config()
    runtime = _make_runtime()

@ -223,6 +288,56 @@ def test_task_tool_propagates_tool_groups_to_subagent(monkeypatch):
    get_available_tools.assert_called_once_with(model_name="ark-model", groups=parent_tool_groups, subagent_enabled=False)


def test_task_tool_uses_subagent_model_override_for_tool_loading(monkeypatch):
    """Subagent model overrides should drive model-gated tool loading."""
    config = SubagentConfig(
        name="general-purpose",
        description="General helper",
        system_prompt="Base system prompt",
        model="vision-subagent-model",
        max_turns=50,
        timeout_seconds=10,
    )
    runtime = _make_runtime()
    runtime.config["metadata"]["model_name"] = "parent-text-model"
    events = []
    get_available_tools = MagicMock(return_value=[])

    class DummyExecutor:
        def __init__(self, **kwargs):
            pass

        def execute_async(self, prompt, task_id=None):
            return task_id or "generated-task-id"

    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
    monkeypatch.setattr(task_tool_module, "SubagentExecutor", DummyExecutor)
    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
    monkeypatch.setattr(
        task_tool_module,
        "get_background_task_result",
        lambda _: _make_result(FakeSubagentStatus.COMPLETED, result="done"),
    )
    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
    monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)

    output = _run_task_tool(
        runtime=runtime,
        description="inspect image",
        prompt="inspect the uploaded image",
        subagent_type="general-purpose",
        tool_call_id="tc-issue-2543",
    )

    assert output == "Task Succeeded. Result: done"
    get_available_tools.assert_called_once_with(
        model_name="vision-subagent-model",
        groups=None,
        subagent_enabled=False,
    )


def test_task_tool_inherits_parent_skill_allowlist_for_default_subagent(monkeypatch):
    config = _make_subagent_config()
    runtime = _make_runtime()

@ -371,6 +486,8 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
    monkeypatch.setattr("deerflow.tools.get_available_tools", get_available_tools)
    fallback_app_config = SimpleNamespace(models=[SimpleNamespace(name="default-model")])
monkeypatch.setattr(task_tool_module, "get_app_config", lambda: fallback_app_config)
|
||||
|
||||
output = _run_task_tool(
|
||||
runtime=None,
|
||||
@ -381,8 +498,13 @@ def test_task_tool_runtime_none_passes_groups_none(monkeypatch):
|
||||
)
|
||||
|
||||
assert output == "Task Succeeded. Result: ok"
|
||||
# runtime is None → metadata is empty dict → groups=None
|
||||
get_available_tools.assert_called_once_with(model_name=None, groups=None, subagent_enabled=False)
|
||||
# runtime is None -> metadata is empty dict -> groups=None, model falls back to app default.
|
||||
get_available_tools.assert_called_once_with(
|
||||
model_name="default-model",
|
||||
groups=None,
|
||||
subagent_enabled=False,
|
||||
app_config=fallback_app_config,
|
||||
)
|
||||
|
||||
config = _make_subagent_config()
|
||||
events = []
|
||||
|
||||
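The hunks above pin how `task_tool` resolves its dependencies. A minimal sketch of the resolution order the tests imply (the helper names below are hypothetical; only `runtime.context["app_config"]`, `get_app_config()`, and the run-metadata shapes come from the tests themselves):

```python
# Hypothetical helpers illustrating the contract pinned by the tests above;
# not the actual task_tool implementation.
def _resolve_app_config(runtime):
    # A per-request app_config threaded through runtime.context wins;
    # otherwise fall back to the ambient global config.
    if runtime is not None and "app_config" in runtime.context:
        return runtime.context["app_config"]
    return get_app_config()  # assumed ambient accessor, as monkeypatched above


def _resolve_model_name(runtime, app_config):
    # The parent model comes from run metadata; with runtime=None the first
    # configured model acts as the default, as the last test expects.
    metadata = runtime.config.get("metadata", {}) if runtime is not None else {}
    return metadata.get("model_name") or app_config.models[0].name
```

The same resolved `app_config` is then expected to reach `get_available_subagent_names`, `get_subagent_config`, `is_host_bash_allowed`, `get_available_tools`, and the `SubagentExecutor` kwargs, which is exactly what the `captured` assertions verify.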
@ -1,12 +1,66 @@
import re
from unittest.mock import patch

import pytest
from _router_auth_helpers import make_authed_test_app
from fastapi import HTTPException
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore

from app.gateway.routers import threads
from deerflow.config.paths import Paths
from deerflow.persistence.thread_meta.memory import THREADS_NS, MemoryThreadMetaStore

_ISO_TIMESTAMP_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")


class _PermissiveThreadMetaStore(MemoryThreadMetaStore):
    """Memory store that skips user-id filtering for router tests.

    Owner isolation is exercised separately in
    ``test_memory_thread_meta_isolation.py``. Router tests need to drive
    the FastAPI surface end-to-end with a single fixed app user, but the
    stub auth middleware in ``_router_auth_helpers`` stamps a fresh UUID
    on every request, so the production filtering would reject every
    pre-seeded record. Bypass that filter so the test can focus on the
    timestamp wire format.
    """

    async def _get_owned_record(self, thread_id, user_id, method_name):  # type: ignore[override]
        item = await self._store.aget(THREADS_NS, thread_id)
        return dict(item.value) if item is not None else None

    async def check_access(self, thread_id, user_id, *, require_existing=False):  # type: ignore[override]
        item = await self._store.aget(THREADS_NS, thread_id)
        if item is None:
            return not require_existing
        return True

    async def create(self, thread_id, *, assistant_id=None, user_id=None, display_name=None, metadata=None):  # type: ignore[override]
        return await super().create(thread_id, assistant_id=assistant_id, user_id=None, display_name=display_name, metadata=metadata)

    async def search(self, *, metadata=None, status=None, limit=100, offset=0, user_id=None):  # type: ignore[override]
        return await super().search(metadata=metadata, status=status, limit=limit, offset=offset, user_id=None)


def _build_thread_app() -> tuple[FastAPI, InMemoryStore, InMemorySaver]:
    """Build a stub-authed FastAPI app wired with an in-memory ThreadMetaStore.

    The thread_store on ``app.state`` is a permissive subclass of
    ``MemoryThreadMetaStore`` so tests can drive ``/api/threads``
    end-to-end and pre-seed legacy records via the underlying BaseStore.

    Returns ``(app, store, checkpointer)`` for direct seeding/inspection.
    """
    app = make_authed_test_app()
    store = InMemoryStore()
    checkpointer = InMemorySaver()
    app.state.store = store
    app.state.checkpointer = checkpointer
    app.state.thread_store = _PermissiveThreadMetaStore(store)
    app.include_router(threads.router)
    return app, store, checkpointer


def test_delete_thread_data_removes_thread_directory(tmp_path):
@ -136,3 +190,244 @@ def test_strip_reserved_metadata_empty_input():
def test_strip_reserved_metadata_strips_all_reserved_keys():
    out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
    assert out == {"keep": "me"}


# ---------------------------------------------------------------------------
# ISO 8601 timestamp contract (issue #2594)
# ---------------------------------------------------------------------------
#
# Threads endpoints document ``created_at`` / ``updated_at`` as ISO
# timestamps and that is the format LangGraph Platform uses
# (``langgraph_sdk.schema.Thread.created_at: datetime`` JSON-encodes to
# ISO 8601). The tests below pin that contract end-to-end and also
# exercise the ``coerce_iso`` healing path for legacy unix-timestamp
# records written by older Gateway versions.


def test_create_thread_returns_iso_timestamps() -> None:
    app, _store, _checkpointer = _build_thread_app()

    with TestClient(app) as client:
        response = client.post("/api/threads", json={"metadata": {}})

    assert response.status_code == 200, response.text
    body = response.json()
    assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
    assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"]
    assert body["created_at"] == body["updated_at"]


def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
    """A thread record written by older versions stores ``time.time()``
    floats. ``get_thread`` must transparently surface them as ISO so the
    frontend's ``new Date(...)`` parser does not break.
    """
    app, store, checkpointer = _build_thread_app()

    legacy_thread_id = "legacy-thread"
    legacy_ts = "1777252410.411327"

    async def _seed() -> None:
        await store.aput(
            THREADS_NS,
            legacy_thread_id,
            {
                "thread_id": legacy_thread_id,
                "status": "idle",
                "created_at": legacy_ts,
                "updated_at": legacy_ts,
                "metadata": {},
            },
        )
        from langgraph.checkpoint.base import empty_checkpoint

        await checkpointer.aput(
            {"configurable": {"thread_id": legacy_thread_id, "checkpoint_ns": ""}},
            empty_checkpoint(),
            {"step": -1, "source": "input", "writes": None, "parents": {}},
            {},
        )

    import asyncio

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.get(f"/api/threads/{legacy_thread_id}")

    assert response.status_code == 200, response.text
    body = response.json()
    assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
    assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"]


def test_patch_thread_returns_iso_and_advances_updated_at() -> None:
    app, store, _checkpointer = _build_thread_app()
    thread_id = "patch-target"

    legacy_created = "1777000000.000000"
    legacy_updated = "1777000000.000000"

    async def _seed() -> None:
        await store.aput(
            THREADS_NS,
            thread_id,
            {
                "thread_id": thread_id,
                "status": "idle",
                "created_at": legacy_created,
                "updated_at": legacy_updated,
                "metadata": {"k": "v0"},
            },
        )

    import asyncio

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.patch(f"/api/threads/{thread_id}", json={"metadata": {"k": "v1"}})

    assert response.status_code == 200, response.text
    body = response.json()
    assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
    assert _ISO_TIMESTAMP_RE.match(body["updated_at"]), body["updated_at"]
    # Patch issues a fresh ``updated_at`` via ``MemoryThreadMetaStore.update_metadata``,
    # so it must be > the migrated legacy ``created_at`` (both ISO strings
    # sort lexicographically by time when the format is consistent).
    assert body["updated_at"] > body["created_at"]
    assert body["metadata"] == {"k": "v1"}


def test_search_threads_normalizes_legacy_unix_seconds_to_iso() -> None:
    """``MemoryThreadMetaStore`` may hold legacy ``time.time()`` floats
    written by older Gateway versions. ``/search`` must surface them as
    ISO via ``coerce_iso`` so the frontend's ``new Date(...)`` parser
    does not break.
    """
    app, store, _checkpointer = _build_thread_app()

    async def _seed() -> None:
        # Legacy unix-second float (the literal value from issue #2594).
        await store.aput(
            THREADS_NS,
            "legacy",
            {
                "thread_id": "legacy",
                "status": "idle",
                "created_at": 1777000000.0,
                "updated_at": 1777000000.0,
                "metadata": {},
            },
        )
        # Modern ISO string, slightly later.
        await store.aput(
            THREADS_NS,
            "modern",
            {
                "thread_id": "modern",
                "status": "idle",
                "created_at": "2026-04-27T00:00:00+00:00",
                "updated_at": "2026-04-27T00:00:00+00:00",
                "metadata": {},
            },
        )

    import asyncio

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.post("/api/threads/search", json={"limit": 10})

    assert response.status_code == 200, response.text
    items = response.json()
    assert {item["thread_id"] for item in items} == {"legacy", "modern"}
    for item in items:
        assert _ISO_TIMESTAMP_RE.match(item["created_at"]), item
        assert _ISO_TIMESTAMP_RE.match(item["updated_at"]), item


def test_memory_thread_meta_store_writes_iso_on_create() -> None:
    """``MemoryThreadMetaStore.create`` must emit ISO so newly created
    threads serialize correctly without depending on the router's
    ``coerce_iso`` heal path.
    """
    import asyncio

    store = InMemoryStore()
    repo = MemoryThreadMetaStore(store)

    async def _scenario() -> dict:
        await repo.create("fresh", user_id=None, metadata={"a": 1})
        record = (await store.aget(THREADS_NS, "fresh")).value
        return record

    record = asyncio.run(_scenario())
    assert _ISO_TIMESTAMP_RE.match(record["created_at"]), record
    assert _ISO_TIMESTAMP_RE.match(record["updated_at"]), record


def test_get_thread_state_returns_iso_for_legacy_checkpoint_metadata() -> None:
    """Checkpoints written by older Gateway versions stored
    ``created_at`` as a unix-second float in their metadata. The
    ``/state`` endpoint must surface that value as ISO so the frontend's
    ``new Date(...)`` parser does not break — same root cause as the
    thread-record bug fixed in #2594, but on the checkpoint side.
    """
    app, _store, checkpointer = _build_thread_app()
    thread_id = "legacy-state"

    async def _seed() -> None:
        from langgraph.checkpoint.base import empty_checkpoint

        await checkpointer.aput(
            {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}},
            empty_checkpoint(),
            {"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327},
            {},
        )

    import asyncio

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.get(f"/api/threads/{thread_id}/state")

    assert response.status_code == 200, response.text
    body = response.json()
    assert _ISO_TIMESTAMP_RE.match(body["created_at"]), body["created_at"]
    assert _ISO_TIMESTAMP_RE.match(body["checkpoint"]["ts"]), body["checkpoint"]


def test_get_thread_history_returns_iso_for_legacy_checkpoint_metadata() -> None:
    """``/history`` walks ``checkpointer.alist`` and emits one entry per
    checkpoint. Each entry's ``created_at`` must come out as ISO even if
    older checkpoints stored a unix-second float in their metadata.
    """
    app, _store, checkpointer = _build_thread_app()
    thread_id = "legacy-history"

    async def _seed() -> None:
        from langgraph.checkpoint.base import empty_checkpoint

        await checkpointer.aput(
            {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}},
            empty_checkpoint(),
            {"step": -1, "source": "input", "writes": None, "parents": {}, "created_at": 1777252410.411327},
            {},
        )

    import asyncio

    asyncio.run(_seed())

    with TestClient(app) as client:
        response = client.post(f"/api/threads/{thread_id}/history", json={"limit": 10})

    assert response.status_code == 200, response.text
    entries = response.json()
    assert entries, "expected at least one history entry"
    for entry in entries:
        assert _ISO_TIMESTAMP_RE.match(entry["created_at"]), entry
@ -1,6 +1,7 @@
"""Core behavior tests for TitleMiddleware."""

import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

from langchain_core.messages import AIMessage, HumanMessage
@ -98,6 +99,34 @@ class TestTitleMiddlewareCoreLogic:
            "tags": ["middleware:title"],
        }

    def test_generate_title_uses_explicit_app_config_without_global_config(self, monkeypatch):
        title_config = TitleConfig(enabled=True, model_name="title-model", max_chars=20)
        app_config = SimpleNamespace(title=title_config)
        middleware = TitleMiddleware(app_config=app_config)
        model = MagicMock()
        model.ainvoke = AsyncMock(return_value=AIMessage(content="显式标题"))

        def fail_get_title_config():
            raise AssertionError("ambient get_title_config() must not be used when app_config is explicit")

        monkeypatch.setattr(title_middleware_module, "get_title_config", fail_get_title_config)
        monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))

        state = {
            "messages": [
                HumanMessage(content="请帮我写一个标题"),
                AIMessage(content="好的"),
            ]
        }
        result = asyncio.run(middleware._agenerate_title_result(state))

        assert result == {"title": "显式标题"}
        title_middleware_module.create_chat_model.assert_called_once_with(
            name="title-model",
            thinking_enabled=False,
            app_config=app_config,
        )

    def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
        _set_test_title_config(max_chars=20)
        middleware = TitleMiddleware()
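The new test pins a precedence rule: an explicitly injected `app_config` must short-circuit the ambient `get_title_config()` lookup. A rough sketch of that rule (hypothetical; the real `TitleMiddleware` internals are not shown in this diff):

```python
# Hypothetical sketch of the config-resolution order the test enforces.
class TitleMiddlewareSketch:
    def __init__(self, app_config=None):
        self._app_config = app_config

    def _resolve_title_config(self):
        if self._app_config is not None:
            # Explicit config: the ambient get_title_config() is never called.
            return self._app_config.title
        return get_title_config()  # ambient fallback for legacy callers
```

The assertion on `create_chat_model` additionally requires that the explicit `app_config` is forwarded to model construction (`name="title-model"`, `thinking_enabled=False`, `app_config=app_config`).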
backend/tests/test_token_usage_middleware.py (new file, 32 lines)
@ -0,0 +1,32 @@
from unittest.mock import MagicMock, patch

from langchain_core.messages import AIMessage

from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware


def test_after_model_logs_usage_metadata_counts():
    middleware = TokenUsageMiddleware()
    state = {
        "messages": [
            AIMessage(
                content="done",
                usage_metadata={
                    "input_tokens": 10,
                    "output_tokens": 5,
                    "total_tokens": 15,
                },
            )
        ]
    }

    with patch("deerflow.agents.middlewares.token_usage_middleware.logger.info") as info_mock:
        result = middleware.after_model(state=state, runtime=MagicMock())

    assert result is None
    info_mock.assert_called_once_with(
        "LLM token usage: input=%s output=%s total=%s",
        10,
        5,
        15,
    )
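A minimal implementation consistent with these assertions might look like the sketch below (an assumption; only the log format string, the argument order, and the `None` return value are pinned by the test):

```python
# Hypothetical sketch of TokenUsageMiddleware.after_model.
import logging

logger = logging.getLogger(__name__)


class TokenUsageMiddlewareSketch:
    def after_model(self, *, state, runtime):
        message = state["messages"][-1]
        usage = getattr(message, "usage_metadata", None)
        if usage:
            logger.info(
                "LLM token usage: input=%s output=%s total=%s",
                usage.get("input_tokens"),
                usage.get("output_tokens"),
                usage.get("total_tokens"),
            )
        return None  # observability only; state is left untouched
```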
@ -9,11 +9,20 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import (
    ToolErrorHandlingMiddleware,
    build_subagent_runtime_middlewares,
)
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.config.app_config import AppConfig, CircuitBreakerConfig
from deerflow.config.guardrails_config import GuardrailsConfig
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig


def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
    tool_call = {"name": name}
    if tool_call_id is not None:
        tool_call["id"] = tool_call_id
    return SimpleNamespace(tool_call=tool_call)


def _module(name: str, **attrs):
    module = ModuleType(name)
    for key, value in attrs.items():
@ -21,19 +30,62 @@ def _module(name: str, **attrs):
    return module


def _make_app_config() -> AppConfig:
def _make_app_config(*, supports_vision: bool = False) -> AppConfig:
    return AppConfig(
        models=[
            ModelConfig(
                name="test-model",
                display_name="test-model",
                description=None,
                use="langchain_openai:ChatOpenAI",
                model="test-model",
                supports_vision=supports_vision,
            )
        ],
        sandbox=SandboxConfig(use="test"),
        guardrails=GuardrailsConfig(enabled=False),
        circuit_breaker=CircuitBreakerConfig(failure_threshold=7, recovery_timeout_sec=11),
    )


def _request(name: str = "web_search", tool_call_id: str | None = "tc-1"):
    tool_call = {"name": name}
    if tool_call_id is not None:
        tool_call["id"] = tool_call_id
    return SimpleNamespace(tool_call=tool_call)
def _stub_runtime_middleware_imports(monkeypatch: pytest.MonkeyPatch) -> None:
    class FakeMiddleware:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class FakeLLMErrorHandlingMiddleware:
        def __init__(self, *, app_config):
            self.app_config = app_config

    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.llm_error_handling_middleware",
        _module(
            "deerflow.agents.middlewares.llm_error_handling_middleware",
            LLMErrorHandlingMiddleware=FakeLLMErrorHandlingMiddleware,
        ),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.thread_data_middleware",
        _module("deerflow.agents.middlewares.thread_data_middleware", ThreadDataMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.sandbox.middleware",
        _module("deerflow.sandbox.middleware", SandboxMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.dangling_tool_call_middleware",
        _module("deerflow.agents.middlewares.dangling_tool_call_middleware", DanglingToolCallMiddleware=FakeMiddleware),
    )
    monkeypatch.setitem(
        sys.modules,
        "deerflow.agents.middlewares.sandbox_audit_middleware",
        _module("deerflow.agents.middlewares.sandbox_audit_middleware", SandboxAuditMiddleware=FakeMiddleware),
    )


def test_build_subagent_runtime_middlewares_threads_app_config_to_llm_middleware(monkeypatch: pytest.MonkeyPatch):
@ -166,3 +218,30 @@ async def test_awrap_tool_call_reraises_graph_interrupt():

    with pytest.raises(GraphInterrupt):
        await middleware.awrap_tool_call(req, _interrupt)


def test_subagent_runtime_middlewares_include_view_image_for_vision_model(monkeypatch):
    app_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")

    assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)


def test_subagent_runtime_middlewares_include_view_image_for_default_vision_model(monkeypatch):
    app_config = _make_app_config(supports_vision=True)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name=None)

    assert any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)


def test_subagent_runtime_middlewares_skip_view_image_for_text_model(monkeypatch):
    app_config = _make_app_config(supports_vision=False)
    _stub_runtime_middleware_imports(monkeypatch)

    middlewares = build_subagent_runtime_middlewares(app_config=app_config, model_name="test-model")

    assert not any(isinstance(middleware, ViewImageMiddleware) for middleware in middlewares)
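Taken together, the three new tests pin a simple gate: `ViewImageMiddleware` is present exactly when the resolved model is vision-capable, with `model_name=None` falling back to the default (first) configured model. A hedged sketch of that gate (the helper name is hypothetical):

```python
# Hypothetical vision gate implied by the three tests above.
def _model_supports_vision(app_config, model_name):
    if model_name is None:
        model = app_config.models[0]  # default model, per the second test
    else:
        model = next((m for m in app_config.models if m.name == model_name), None)
    return bool(model and model.supports_vision)
```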
@ -1,14 +1,20 @@
"""Tests for deerflow.uploads.manager — shared upload management logic."""

import errno
import os
from unittest.mock import patch

import pytest

from deerflow.uploads.manager import (
    PathTraversalError,
    UnsafeUploadPathError,
    claim_unique_filename,
    delete_file_safe,
    list_files_in_dir,
    normalize_filename,
    validate_path_traversal,
    write_upload_file_no_symlink,
)

# ---------------------------------------------------------------------------
@ -97,6 +103,54 @@ class TestValidatePathTraversal:
            validate_path_traversal(link, tmp_path)


# ---------------------------------------------------------------------------
# write_upload_file_no_symlink
# ---------------------------------------------------------------------------


class TestWriteUploadFileNoSymlink:
    def test_writes_new_file(self, tmp_path):
        dest = write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")

        assert dest == tmp_path / "notes.txt"
        assert dest.read_bytes() == b"hello"

    def test_overwrites_existing_regular_file_with_single_link(self, tmp_path):
        dest = tmp_path / "notes.txt"
        dest.write_bytes(b"old contents")
        assert os.stat(dest).st_nlink == 1

        result = write_upload_file_no_symlink(tmp_path, "notes.txt", b"new contents")

        assert result == dest
        assert dest.read_bytes() == b"new contents"
        assert os.stat(dest).st_nlink == 1

    def test_fails_closed_without_no_follow_support(self, tmp_path, monkeypatch):
        monkeypatch.delattr(os, "O_NOFOLLOW", raising=False)

        with pytest.raises(UnsafeUploadPathError, match="O_NOFOLLOW"):
            write_upload_file_no_symlink(tmp_path, "notes.txt", b"hello")

        assert not (tmp_path / "notes.txt").exists()

    def test_open_uses_nonblocking_flag_when_available(self, tmp_path):
        with patch("deerflow.uploads.manager.os.open", side_effect=OSError(errno.ENXIO, "no reader")) as open_mock:
            with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
                write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")

        flags = open_mock.call_args.args[1]
        assert flags & os.O_NONBLOCK

    @pytest.mark.parametrize("open_errno", [errno.ENXIO, errno.EAGAIN])
    def test_nonblocking_special_file_open_errors_are_unsafe(self, tmp_path, open_errno):
        with patch("deerflow.uploads.manager.os.open", side_effect=OSError(open_errno, "would block")):
            with pytest.raises(UnsafeUploadPathError, match="Unsafe upload destination"):
                write_upload_file_no_symlink(tmp_path, "pipe.txt", b"hello")

        assert not (tmp_path / "pipe.txt").exists()
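The class above fixes the helper's safety contract: fail closed when `O_NOFOLLOW` is unavailable, open non-blocking so FIFOs and devices error out instead of hanging, and never follow a symlink. A sketch consistent with those tests (an assumption; the real implementation lives in `deerflow.uploads.manager` and may differ, for example in how it verifies hardlink counts):

```python
# Hypothetical sketch of write_upload_file_no_symlink's fail-closed open.
import os
import stat
from pathlib import Path


class UnsafeUploadPathError(Exception):  # stands in for the real exception
    pass


def write_upload_file_no_symlink_sketch(directory: Path, name: str, data: bytes) -> Path:
    dest = directory / name
    if not hasattr(os, "O_NOFOLLOW"):
        # Fail closed: without O_NOFOLLOW a symlink swap cannot be ruled out.
        raise UnsafeUploadPathError("platform lacks O_NOFOLLOW support")
    # No O_TRUNC at open time, so a rejected destination is never clobbered.
    flags = os.O_WRONLY | os.O_CREAT | os.O_NOFOLLOW | os.O_NONBLOCK
    try:
        fd = os.open(dest, flags, 0o600)
    except OSError as exc:
        # ELOOP: symlink destination; ENXIO/EAGAIN: FIFO or device opened
        # non-blocking. All of these are unsafe upload targets.
        raise UnsafeUploadPathError(f"Unsafe upload destination: {dest}") from exc
    try:
        info = os.fstat(fd)
        if not stat.S_ISREG(info.st_mode) or info.st_nlink != 1:
            raise UnsafeUploadPathError(f"Unsafe upload destination: {dest}")
        os.ftruncate(fd, 0)
        os.write(fd, data)
    finally:
        os.close(fd)
    return dest
```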
# ---------------------------------------------------------------------------
# list_files_in_dir
# ---------------------------------------------------------------------------
@ -1,16 +1,40 @@
import asyncio
import os
import stat
from io import BytesIO
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

from _router_auth_helpers import call_unwrapped
from fastapi import UploadFile
import pytest
from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi import HTTPException, UploadFile
from fastapi.testclient import TestClient

from app.gateway.routers import uploads


class ChunkedUpload:
    def __init__(self, filename: str, chunks: list[bytes]):
        self.filename = filename
        self._chunks = list(chunks)
        self.read_calls: list[int | None] = []

    async def read(self, size: int | None = None) -> bytes:
        self.read_calls.append(size)
        if size is None:
            raise AssertionError("upload must be read with an explicit chunk size")
        if not self._chunks:
            return b""
        return self._chunks.pop(0)


def _mounted_provider() -> MagicMock:
    provider = MagicMock()
    provider.uses_thread_data_mounts = True
    return provider


def test_upload_files_writes_thread_storage_and_skips_local_sandbox_sync(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)
@ -178,6 +202,173 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
    make_writable.assert_not_called()


def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    sandbox = MagicMock()
    provider.get.return_value = sandbox

    def acquire_before_writes(thread_id: str) -> str:
        assert list(thread_uploads_dir.iterdir()) == []
        return "aio-1"

    provider.acquire.side_effect = acquire_before_writes

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
    ):
        file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
        result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))

    assert result.success is True
    provider.acquire.assert_called_once_with("thread-aio")
    sandbox.update_file.assert_called_once_with("/mnt/user-data/uploads/notes.txt", b"hello uploads")


def test_upload_files_fails_before_writing_when_non_local_sandbox_unavailable(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    provider.acquire.side_effect = RuntimeError("sandbox unavailable")
    file = ChunkedUpload("notes.txt", [b"hello uploads"])

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
    ):
        with pytest.raises(RuntimeError, match="sandbox unavailable"):
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))

    assert list(thread_uploads_dir.iterdir()) == []
    assert file.read_calls == []
    provider.get.assert_not_called()


def test_upload_files_rejects_too_many_files_before_writing(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=1, max_file_size=10, max_total_size=20)),
    ):
        files = [
            ChunkedUpload("one.txt", [b"one"]),
            ChunkedUpload("two.txt", [b"two"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    assert list(thread_uploads_dir.iterdir()) == []
    assert files[0].read_calls == []
    assert files[1].read_calls == []


def test_upload_files_rejects_oversized_single_file_and_removes_partial_file(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    provider = _mounted_provider()
    file = ChunkedUpload("big.txt", [b"123456"])

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=5, max_total_size=20)),
    ):
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    assert not (thread_uploads_dir / "big.txt").exists()
    assert file.read_calls == [8192]
    provider.acquire.assert_not_called()


def test_upload_files_rejects_total_size_over_limit_and_cleans_request_files(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)),
    ):
        files = [
            ChunkedUpload("first.txt", [b"123"]),
            ChunkedUpload("second.txt", [b"456"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    assert not (thread_uploads_dir / "first.txt").exists()
    assert not (thread_uploads_dir / "second.txt").exists()


def test_upload_files_does_not_sync_non_local_sandbox_when_total_size_exceeds_limit(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    provider.acquire.return_value = "aio-1"
    sandbox = MagicMock()
    provider.get.return_value = sandbox

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        patch.object(uploads, "_get_upload_limits", return_value=uploads.UploadLimits(max_files=10, max_file_size=10, max_total_size=5)),
    ):
        files = [
            ChunkedUpload("first.txt", [b"123"]),
            ChunkedUpload("second.txt", [b"456"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=files, config=SimpleNamespace()))

    assert exc_info.value.status_code == 413
    provider.acquire.assert_called_once_with("thread-aio")
    provider.get.assert_called_once_with("aio-1")
    sandbox.update_file.assert_not_called()


def test_upload_files_does_not_sync_non_local_sandbox_when_conversion_fails(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    provider = MagicMock()
    provider.uses_thread_data_mounts = False
    provider.acquire.return_value = "aio-1"
    sandbox = MagicMock()
    provider.get.return_value = sandbox

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        patch.object(uploads, "_auto_convert_documents_enabled", return_value=True),
        patch.object(uploads, "convert_file_to_markdown", AsyncMock(side_effect=RuntimeError("conversion failed"))),
    ):
        file = UploadFile(filename="report.pdf", file=BytesIO(b"pdf-bytes"))
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-aio", request=MagicMock(), files=[file], config=SimpleNamespace()))

    assert exc_info.value.status_code == 500
    provider.acquire.assert_called_once_with("thread-aio")
    provider.get.assert_called_once_with("aio-1")
    sandbox.update_file.assert_not_called()
    assert not (thread_uploads_dir / "report.pdf").exists()


def test_make_file_sandbox_writable_adds_write_bits_for_regular_files(tmp_path):
    file_path = tmp_path / "report.pdf"
    file_path.write_bytes(b"pdf-bytes")
@ -238,6 +429,105 @@ def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
    assert [f.name for f in thread_uploads_dir.iterdir()] == ["passwd"]


def test_upload_files_rejects_preexisting_symlink_destination(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)
    outside_file = tmp_path / "outside.txt"
    outside_file.write_text("protected", encoding="utf-8")
    (thread_uploads_dir / "victim.txt").symlink_to(outside_file)

    provider = MagicMock()
    provider.uses_thread_data_mounts = True

    with (
        patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
    ):
        file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))
        result = asyncio.run(uploads.upload_files("thread-local", files=[file]))

    assert result.success is False
    assert result.files == []
    assert result.skipped_files == ["victim.txt"]
    assert "skipped 1 unsafe file" in result.message
    assert outside_file.read_text(encoding="utf-8") == "protected"
    assert (thread_uploads_dir / "victim.txt").is_symlink()


def test_upload_files_rejects_dangling_symlink_destination(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)
    missing_target = tmp_path / "missing-target.txt"
    (thread_uploads_dir / "victim.txt").symlink_to(missing_target)

    provider = MagicMock()
    provider.uses_thread_data_mounts = True

    with (
        patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
    ):
        file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))
        result = asyncio.run(uploads.upload_files("thread-local", files=[file]))

    assert result.success is False
    assert result.files == []
    assert result.skipped_files == ["victim.txt"]
    assert not missing_target.exists()
    assert (thread_uploads_dir / "victim.txt").is_symlink()


def test_upload_files_rejects_hardlinked_destination_without_truncating(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)
    outside_file = tmp_path / "outside.txt"
    outside_file.write_text("protected", encoding="utf-8")
    os.link(outside_file, thread_uploads_dir / "victim.txt")

    provider = MagicMock()
    provider.uses_thread_data_mounts = True

    with (
        patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
    ):
        file = UploadFile(filename="victim.txt", file=BytesIO(b"attacker upload"))
        result = asyncio.run(uploads.upload_files("thread-local", files=[file]))

    assert result.success is False
    assert result.files == []
    assert result.skipped_files == ["victim.txt"]
    assert outside_file.read_text(encoding="utf-8") == "protected"
    assert (thread_uploads_dir / "victim.txt").read_text(encoding="utf-8") == "protected"


def test_upload_files_overwrites_existing_regular_file(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)
    existing_file = thread_uploads_dir / "notes.txt"
    existing_file.write_bytes(b"old upload")
    assert existing_file.stat().st_nlink == 1

    provider = MagicMock()
    provider.uses_thread_data_mounts = True

    with (
        patch.object(uploads, "get_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
    ):
        file = UploadFile(filename="notes.txt", file=BytesIO(b"new upload"))
        result = asyncio.run(uploads.upload_files("thread-local", files=[file]))

    assert result.success is True
    assert [file_info["filename"] for file_info in result.files] == ["notes.txt"]
    assert existing_file.read_bytes() == b"new upload"
    assert existing_file.stat().st_nlink == 1


def test_delete_uploaded_file_removes_generated_markdown_companion(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)
@ -286,3 +576,65 @@ def test_auto_convert_documents_enabled_accepts_boolean_and_string_truthy_values
    assert uploads._auto_convert_documents_enabled(true_cfg) is True
    assert uploads._auto_convert_documents_enabled(string_true_cfg) is True
    assert uploads._auto_convert_documents_enabled(string_false_cfg) is False


def test_upload_limits_endpoint_reads_uploads_config():
    cfg = MagicMock()
    cfg.uploads = {
        "max_files": 15,
        "max_file_size": "1048576",
        "max_total_size": 2097152,
    }

    result = asyncio.run(call_unwrapped(uploads.get_upload_limits, "thread-local", request=MagicMock(), config=cfg))

    assert result.max_files == 15
    assert result.max_file_size == 1048576
    assert result.max_total_size == 2097152


def test_upload_limits_endpoint_requires_thread_access():
    cfg = MagicMock()
    cfg.uploads = {}
    app = make_authed_test_app(owner_check_passes=False)
    app.state.config = cfg
    app.include_router(uploads.router)

    with TestClient(app) as client:
        response = client.get("/api/threads/thread-local/uploads/limits")

    assert response.status_code == 404


def test_upload_limits_accept_legacy_config_keys():
    cfg = MagicMock()
    cfg.uploads = {
        "max_file_count": 7,
        "max_single_file_size": 123,
        "max_total_size": 456,
    }

    limits = uploads._get_upload_limits(cfg)

    assert limits == uploads.UploadLimits(max_files=7, max_file_size=123, max_total_size=456)


def test_upload_files_uses_configured_file_count_limit(tmp_path):
    thread_uploads_dir = tmp_path / "uploads"
    thread_uploads_dir.mkdir(parents=True)

    cfg = MagicMock()
    cfg.uploads = {"max_files": 1}

    with (
        patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=_mounted_provider()),
    ):
        files = [
            ChunkedUpload("one.txt", [b"one"]),
            ChunkedUpload("two.txt", [b"two"]),
        ]
        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=files, config=cfg))

    assert exc_info.value.status_code == 413
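These limit tests imply a small resolution helper: modern keys first, legacy aliases accepted, string values coerced to int, and defaults applied otherwise. A sketch under those assumptions (the defaults are taken from the `config.yaml` hunk further down: 10 files, 50 MiB per file, 100 MiB total):

```python
# Hypothetical sketch of _get_upload_limits' key resolution.
from dataclasses import dataclass


@dataclass(frozen=True)
class UploadLimitsSketch:
    max_files: int
    max_file_size: int
    max_total_size: int


def get_upload_limits_sketch(cfg) -> UploadLimitsSketch:
    uploads_cfg = getattr(cfg, "uploads", None) or {}

    def pick(default, *keys):
        for key in keys:
            if key in uploads_cfg:
                return int(uploads_cfg[key])  # string values coerce to int
        return default

    return UploadLimitsSketch(
        max_files=pick(10, "max_files", "max_file_count"),
        max_file_size=pick(52428800, "max_file_size", "max_single_file_size"),
        max_total_size=pick(104857600, "max_total_size"),
    )
```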
backend/tests/test_utils_time.py (new file, 90 lines)
@ -0,0 +1,90 @@
"""Tests for ``deerflow.utils.time``."""

from __future__ import annotations

import re
from datetime import UTC, datetime, timedelta, timezone

from deerflow.utils.time import coerce_iso, now_iso

_ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")


def test_now_iso_is_utc_iso8601() -> None:
    value = now_iso()
    assert _ISO_RE.match(value), value
    parsed = datetime.fromisoformat(value)
    assert parsed.tzinfo is not None
    assert parsed.tzinfo.utcoffset(parsed) == UTC.utcoffset(parsed)


def test_coerce_iso_passes_iso_through() -> None:
    iso = "2026-04-27T01:13:30.411334+00:00"
    assert coerce_iso(iso) == iso


def test_coerce_iso_converts_unix_float_string() -> None:
    legacy = "1777252410.411327"
    out = coerce_iso(legacy)
    assert _ISO_RE.match(out), out
    # Round-trip: parsed timestamp matches the original epoch.
    parsed = datetime.fromisoformat(out)
    assert abs(parsed.timestamp() - 1777252410.411327) < 1e-3


def test_coerce_iso_converts_unix_int_string() -> None:
    out = coerce_iso("1700000000")
    assert _ISO_RE.match(out), out


def test_coerce_iso_converts_numeric_types() -> None:
    out_float = coerce_iso(1777252410.411327)
    out_int = coerce_iso(1700000000)
    assert _ISO_RE.match(out_float)
    assert _ISO_RE.match(out_int)


def test_coerce_iso_handles_empty_and_none() -> None:
    assert coerce_iso(None) == ""
    assert coerce_iso("") == ""


def test_coerce_iso_does_not_misinterpret_short_numeric() -> None:
    # A 4-digit year should never be parsed as a unix timestamp; only
    # 10-digit unix-second strings match the legacy pattern.
    assert coerce_iso("2026") == "2026"


def test_coerce_iso_handles_unparseable_string() -> None:
    assert coerce_iso("not-a-timestamp") == "not-a-timestamp"


def test_coerce_iso_rejects_bool() -> None:
    # ``bool`` is a subclass of ``int`` — must not be treated as epoch 0/1.
    assert coerce_iso(True) == "True"
    assert coerce_iso(False) == "False"


def test_coerce_iso_handles_tz_aware_datetime() -> None:
    # str(datetime) would emit a space separator; coerce_iso must use ``T``.
    dt = datetime(2026, 4, 27, 1, 13, 30, 411327, tzinfo=UTC)
    out = coerce_iso(dt)
    assert out == "2026-04-27T01:13:30.411327+00:00"
    assert "T" in out and " " not in out


def test_coerce_iso_handles_tz_naive_datetime_as_utc() -> None:
    dt = datetime(2026, 4, 27, 1, 13, 30, 411327)
    out = coerce_iso(dt)
    assert out == "2026-04-27T01:13:30.411327+00:00"
    parsed = datetime.fromisoformat(out)
    assert parsed.tzinfo is not None
    assert parsed.utcoffset() == timedelta(0)


def test_coerce_iso_normalises_non_utc_datetime_to_utc() -> None:
    # +08:00 wall-clock 09:13 == UTC 01:13.
    plus_eight = timezone(timedelta(hours=8))
    dt = datetime(2026, 4, 27, 9, 13, 30, 411327, tzinfo=plus_eight)
    out = coerce_iso(dt)
    assert out == "2026-04-27T01:13:30.411327+00:00"
|
||||
#
|
||||
# Guidelines:
|
||||
# - Copy this file to `config.yaml` and customize it for your environment
|
||||
# - The default path of this configuration file is `config.yaml` in the current working directory.
|
||||
# However you can change it using the `DEER_FLOW_CONFIG_PATH` environment variable.
|
||||
# - The default path of this configuration file is `config.yaml` in the project root.
|
||||
# You can set `DEER_FLOW_PROJECT_ROOT` to define that root explicitly, or use
|
||||
# `DEER_FLOW_CONFIG_PATH` to point at a specific config file.
|
||||
# - Runtime state defaults to `.deer-flow` under the project root. Override it
|
||||
# with `DEER_FLOW_HOME` when you need a different writable data directory.
|
||||
# - Environment variables are available for all field values. Example: `api_key: $OPENAI_API_KEY`
|
||||
# - The `use` path is a string that looks like "package_name.sub_package_name.module_name:class_name/variable_name".
|
||||
|
||||
@ -370,6 +373,16 @@ tools:
|
||||
use: deerflow.community.ddg_search.tools:web_search_tool
|
||||
max_results: 5
|
||||
|
||||
# Web search tool (uses Serper - Google Search API, requires SERPER_API_KEY)
|
||||
# Serper provides real-time Google Search results. Sign up at https://serper.dev
|
||||
# Note: set SERPER_API_KEY in your environment before starting the app, or
|
||||
# uncomment and fill in api_key below (the $VAR syntax is resolved at startup).
|
||||
# - name: web_search
|
||||
# group: web
|
||||
# use: deerflow.community.serper.tools:web_search_tool
|
||||
# max_results: 5
|
||||
# # api_key: $SERPER_API_KEY # Optional if SERPER_API_KEY env var is set
|
||||
|
||||
# Web search tool (requires Tavily API key)
|
||||
# - name: web_search
|
||||
# group: web
|
||||
@ -501,6 +514,11 @@ tool_search:
|
||||
# Option 1: Local Sandbox (Default)
|
||||
# Executes commands directly on the host machine
|
||||
uploads:
|
||||
# Application-level upload limits enforced by the gateway and exposed to the
|
||||
# frontend before file selection.
|
||||
max_files: 10
|
||||
max_file_size: 52428800 # 50 MiB
|
||||
max_total_size: 104857600 # 100 MiB
|
||||
# Automatic Office/PDF conversion runs on the backend host before sandbox
|
||||
# isolation applies. Keep this disabled unless uploads come from a fully
|
||||
# trusted source and you intentionally accept host-side parser risk.
|
||||
@ -673,7 +691,8 @@ sandbox:
|
||||
|
||||
skills:
|
||||
# Path to skills directory on the host (relative to project root or absolute)
|
||||
# Default: ../skills (relative to backend directory)
|
||||
# Default: skills under the project root
|
||||
# Override with DEER_FLOW_SKILLS_PATH when this field is omitted.
|
||||
# Uncomment to customize:
|
||||
# path: /absolute/path/to/custom/skills
|
||||
|
||||
|
||||
@ -157,6 +157,7 @@ services:
|
||||
working_dir: /app
|
||||
environment:
|
||||
- CI=true
|
||||
- DEER_FLOW_PROJECT_ROOT=/app
|
||||
- DEER_FLOW_HOME=/app/backend/.deer-flow
|
||||
- DEER_FLOW_CHANNELS_LANGGRAPH_URL=${DEER_FLOW_CHANNELS_LANGGRAPH_URL:-http://gateway:8001/api}
|
||||
- DEER_FLOW_CHANNELS_GATEWAY_URL=${DEER_FLOW_CHANNELS_GATEWAY_URL:-http://gateway:8001}
|
||||
|
||||
@ -8,9 +8,11 @@
|
||||
# - provisioner: (optional) Sandbox provisioner for Kubernetes mode
|
||||
#
|
||||
# Key environment variables (set via environment/.env or scripts/deploy.sh):
|
||||
# DEER_FLOW_HOME — runtime data dir, default $REPO_ROOT/backend/.deer-flow
|
||||
# DEER_FLOW_PROJECT_ROOT — project root for relative runtime paths
|
||||
# DEER_FLOW_HOME — runtime data dir, default .deer-flow under $DEER_FLOW_PROJECT_ROOT (or cwd)
|
||||
# DEER_FLOW_CONFIG_PATH — path to config.yaml
|
||||
# DEER_FLOW_EXTENSIONS_CONFIG_PATH — path to extensions_config.json
|
||||
# DEER_FLOW_SKILLS_PATH — skills dir, default $DEER_FLOW_PROJECT_ROOT/skills
|
||||
# DEER_FLOW_DOCKER_SOCKET — Docker socket path, default /var/run/docker.sock
|
||||
# DEER_FLOW_REPO_ROOT — repo root (used for skills host path in DooD)
|
||||
# BETTER_AUTH_SECRET — required for frontend auth/session security
|
||||
@ -93,6 +95,7 @@ services:
|
||||
working_dir: /app
|
||||
environment:
|
||||
- CI=true
|
||||
- DEER_FLOW_PROJECT_ROOT=/app
|
||||
- DEER_FLOW_HOME=/app/backend/.deer-flow
|
||||
- DEER_FLOW_CONFIG_PATH=/app/backend/config.yaml
|
||||
- DEER_FLOW_EXTENSIONS_CONFIG_PATH=/app/backend/extensions_config.json
|
||||
|
||||