feat(auth): authentication module with multi-tenant isolation (RFC-001)

Introduce an always-on auth layer with an auto-created admin on first boot,
multi-tenant isolation for threads/stores, and a full setup/login flow.

Backend
- JWT access tokens carry a `ver` (token version) claim for stale-token
  rejection; the version is bumped on password/email change (see the sketch
  after this list)
- Password hashing, HttpOnly+Secure cookies (Secure derived from request
  scheme at runtime)
- CSRF middleware covering both REST and LangGraph routes
- IP-based login rate limiting (5 attempts, then a 5-minute lockout) with
  bounded tracking-dict growth and a fix for the X-Forwarded-For spoofing
  bypass (sketched at the end of this message)
- Multi-worker-safe admin auto-creation (single DB write, WAL once)
- needs_setup + token_version on User model; SQLite schema migration
- Thread/store isolation by owner; orphan thread migration on first admin
  registration
- thread_id validated as UUID to prevent log injection
- CLI tool to reset admin password
- Decorator-based authz module extracted from auth core
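
A minimal sketch of the `ver` check, assuming the exports shown later in
this commit (`resolve_user` itself is illustrative, not the shipped
middleware):

```python
from app.gateway.auth import TokenPayload, decode_token
from app.gateway.auth.models import User
from app.gateway.auth.providers import AuthProvider

async def resolve_user(token: str, provider: AuthProvider) -> User | None:
    payload = decode_token(token)            # TokenPayload or TokenError
    if not isinstance(payload, TokenPayload):
        return None                          # expired / bad signature / malformed
    user = await provider.get_user(payload.sub)
    if user is None or payload.ver != user.token_version:
        return None                          # stale: version bumped since issuance
    return user
```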

Frontend
- Login and setup pages with SSR guard for needs_setup flow
- Account settings page (change password / email)
- AuthProvider + route guards; skips redirect when no users registered
- i18n (en-US / zh-CN) for auth surfaces
- Typed auth API client; parseAuthError unwraps FastAPI detail envelope

Infra & tooling
- Unified `serve.sh` with gateway mode + auto dep install
- Public PyPI uv.toml pin for CI compatibility
- Regenerated uv.lock with public index

Tests
- HTTP vs HTTPS cookie security tests
- Auth middleware, rate limiter, CSRF, setup flow coverage
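
A sketch of the lockout policy from the Backend list (names and the eviction
strategy are illustrative; the shipped limiter lives in the gateway):

```python
import time

MAX_ATTEMPTS = 5
WINDOW_SECONDS = 300        # 5-minute lockout window
MAX_TRACKED_IPS = 10_000    # bound on tracking-dict growth

_attempts: dict[str, list[float]] = {}

def allow_login(ip: str) -> bool:
    """`ip` must come from a trusted source, not raw X-Forwarded-For."""
    now = time.monotonic()
    recent = [t for t in _attempts.get(ip, []) if now - t < WINDOW_SECONDS]
    if len(recent) >= MAX_ATTEMPTS:
        _attempts[ip] = recent
        return False        # locked out until attempts age past the window
    if ip not in _attempts and len(_attempts) >= MAX_TRACKED_IPS:
        # Evict the IP whose newest attempt is oldest, so the dict stays bounded.
        stalest = min(_attempts, key=lambda k: max(_attempts[k]))
        _attempts.pop(stalest, None)
    recent.append(now)
    _attempts[ip] = recent
    return True
```
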
Author: greatmengqi · 2026-04-08 00:31:43 +08:00
Parent: 636053fb6d · Commit: 27b66d6753
214 changed files with 18,830 additions and 1,065 deletions

@@ -6,6 +6,11 @@ JINA_API_KEY=your-jina-api-key
# InfoQuest API Key
INFOQUEST_API_KEY=your-infoquest-api-key
# Authentication — JWT secret for session signing
# If not set, an ephemeral secret is auto-generated (sessions lost on restart)
# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
# AUTH_JWT_SECRET=your-secure-jwt-secret-here
# CORS Origins (comma-separated) - e.g., http://localhost:3000,http://localhost:3001
# CORS_ORIGINS=http://localhost:3000
@@ -32,3 +37,5 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# GitHub API Token
# GITHUB_TOKEN=your-github-token
# WECOM_BOT_ID=your-wecom-bot-id
# WECOM_BOT_SECRET=your-wecom-bot-secret

.gitignore

@@ -54,3 +54,4 @@ web/
# Deployment artifacts
backend/Dockerfile.langgraph
config.yaml.bak
.gstack/

@@ -310,7 +310,7 @@ Every pull request runs the backend regression workflow at [.github/workflows/ba
- [Configuration Guide](backend/docs/CONFIGURATION.md) - Setup and configuration
- [Architecture Overview](backend/CLAUDE.md) - Technical architecture
- [MCP Setup Guide](MCP_SETUP.md) - Model Context Protocol configuration
- [MCP Setup Guide](backend/docs/MCP_SERVER.md) - Model Context Protocol configuration
## Need Help?

@@ -1,6 +1,6 @@
# DeerFlow - Unified Development Environment
.PHONY: help config config-upgrade check install dev dev-daemon start stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway
.PHONY: help config config-upgrade check install dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway
BASH ?= bash
@@ -20,18 +20,25 @@ help:
@echo " make install - Install all dependencies (frontend + backend)"
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
@echo " make dev - Start all services in development mode (with hot-reloading)"
@echo " make dev-daemon - Start all services in background (daemon mode)"
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
@echo " make dev-daemon - Start dev services in background (daemon mode)"
@echo " make dev-daemon-pro - Start dev daemon + Gateway mode (experimental)"
@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
@echo " make start-pro - Start in prod + Gateway mode (experimental)"
@echo " make start-daemon - Start prod services in background (daemon mode)"
@echo " make start-daemon-pro - Start prod daemon + Gateway mode (experimental)"
@echo " make stop - Stop all running services"
@echo " make clean - Clean up processes and temporary files"
@echo ""
@echo "Docker Production Commands:"
@echo " make up - Build and start production Docker services (localhost:2026)"
@echo " make up-pro - Build and start production Docker in Gateway mode (experimental)"
@echo " make down - Stop and remove production Docker containers"
@echo ""
@echo "Docker Development Commands:"
@echo " make docker-init - Pull the sandbox image"
@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
@echo " make docker-start-pro - Start Docker in Gateway mode (experimental, no LangGraph container)"
@echo " make docker-stop - Stop Docker development services"
@echo " make docker-logs - View Docker development logs"
@echo " make docker-logs-frontend - View Docker frontend logs"
@@ -105,6 +112,15 @@ else
@./scripts/serve.sh --dev
endif
# Start all services in dev + Gateway mode (experimental: agent runtime embedded in Gateway)
dev-pro:
@$(PYTHON) ./scripts/check.py
ifeq ($(OS),Windows_NT)
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway
else
@./scripts/serve.sh --dev --gateway
endif
# Start all services in production mode (with optimizations)
start:
@$(PYTHON) ./scripts/check.py
@@ -114,30 +130,54 @@ else
@./scripts/serve.sh --prod
endif
# Start all services in prod + Gateway mode (experimental)
start-pro:
@$(PYTHON) ./scripts/check.py
ifeq ($(OS),Windows_NT)
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway
else
@./scripts/serve.sh --prod --gateway
endif
# Start all services in daemon mode (background)
dev-daemon:
@$(PYTHON) ./scripts/check.py
ifeq ($(OS),Windows_NT)
@call scripts\run-with-git-bash.cmd ./scripts/start-daemon.sh
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --daemon
else
@./scripts/start-daemon.sh
@./scripts/serve.sh --dev --daemon
endif
# Start daemon + Gateway mode (experimental)
dev-daemon-pro:
@$(PYTHON) ./scripts/check.py
ifeq ($(OS),Windows_NT)
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway --daemon
else
@./scripts/serve.sh --dev --gateway --daemon
endif
# Start prod services in daemon mode (background)
start-daemon:
@$(PYTHON) ./scripts/check.py
ifeq ($(OS),Windows_NT)
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --daemon
else
@./scripts/serve.sh --prod --daemon
endif
# Start prod daemon + Gateway mode (experimental)
start-daemon-pro:
@$(PYTHON) ./scripts/check.py
ifeq ($(OS),Windows_NT)
@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway --daemon
else
@./scripts/serve.sh --prod --gateway --daemon
endif
# Stop all services
stop:
@echo "Stopping all services..."
@-pkill -f "langgraph dev" 2>/dev/null || true
@-pkill -f "uvicorn app.gateway.app:app" 2>/dev/null || true
@-pkill -f "next dev" 2>/dev/null || true
@-pkill -f "next start" 2>/dev/null || true
@-pkill -f "next-server" 2>/dev/null || true
@-pkill -f "next-server" 2>/dev/null || true
@-nginx -c $(PWD)/docker/nginx/nginx.local.conf -p $(PWD) -s quit 2>/dev/null || true
@sleep 1
@-pkill -9 nginx 2>/dev/null || true
@echo "Cleaning up sandbox containers..."
@-./scripts/cleanup-containers.sh deer-flow-sandbox 2>/dev/null || true
@echo "✓ All services stopped"
@./scripts/serve.sh --stop
# Clean up
clean: stop
@@ -159,6 +199,10 @@ docker-init:
docker-start:
@./scripts/docker.sh start
# Start Docker in Gateway mode (experimental)
docker-start-pro:
@./scripts/docker.sh start --gateway
# Stop Docker development environment
docker-stop:
@./scripts/docker.sh stop
@@ -181,6 +225,10 @@ docker-logs-gateway:
up:
@./scripts/deploy.sh
# Build and start production services in Gateway mode
up-pro:
@./scripts/deploy.sh --gateway
# Stop and remove production containers
down:
@./scripts/deploy.sh down

@@ -46,6 +46,7 @@ DeerFlow has newly integrated the intelligent search and crawling toolset independ
- [🦌 DeerFlow - 2.0](#-deerflow---20)
- [Official Website](#official-website)
- [Coding Plan from ByteDance Volcengine](#coding-plan-from-bytedance-volcengine)
- [InfoQuest](#infoquest)
- [Table of Contents](#table-of-contents)
- [One-Line Agent Setup](#one-line-agent-setup)
@@ -59,6 +60,8 @@ DeerFlow has newly integrated the intelligent search and crawling toolset independ
- [MCP Server](#mcp-server)
- [IM Channels](#im-channels)
- [LangSmith Tracing](#langsmith-tracing)
- [Langfuse Tracing](#langfuse-tracing)
- [Using Both Providers](#using-both-providers)
- [From Deep Research to Super Agent Harness](#from-deep-research-to-super-agent-harness)
- [Core Features](#core-features)
- [Skills \& Tools](#skills--tools)
@@ -71,6 +74,8 @@ DeerFlow has newly integrated the intelligent search and crawling toolset independ
- [Embedded Python Client](#embedded-python-client)
- [Documentation](#documentation)
- [⚠️ Security Notice](#-security-notice)
- [Improper Deployment May Introduce Security Risks](#improper-deployment-may-introduce-security-risks)
- [Security Recommendations](#security-recommendations)
- [Contributing](#contributing)
- [License](#license)
- [Acknowledgments](#acknowledgments)
@@ -275,6 +280,60 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
6. **Access**: http://localhost:2026
#### Startup Modes
DeerFlow supports multiple startup modes across two dimensions:
- **Dev / Prod** — dev enables hot-reload; prod uses pre-built frontend
- **Standard / Gateway** — standard uses a separate LangGraph server (4 processes); Gateway mode (experimental) embeds the agent runtime in the Gateway API (3 processes)
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|---|---|---|---|---|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
| Action | Local | Docker Dev | Docker Prod |
|---|---|---|---|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
> **Gateway mode** eliminates the LangGraph server process — the Gateway API handles agent execution directly via async tasks, managing its own concurrency.
#### Why Gateway Mode?
In standard mode, DeerFlow runs a dedicated [LangGraph Platform](https://langchain-ai.github.io/langgraph/) server alongside the Gateway API. This architecture works well but has trade-offs:
| | Standard Mode | Gateway Mode |
|---|---|---|
| **Architecture** | Gateway (REST API) + LangGraph (agent runtime) | Gateway embeds agent runtime |
| **Concurrency** | `--n-jobs-per-worker` per worker (requires license) | `--workers` × async tasks (no per-worker cap) |
| **Containers / Processes** | 4 (frontend, gateway, langgraph, nginx) | 3 (frontend, gateway, nginx) |
| **Resource usage** | Higher (two Python runtimes) | Lower (single Python runtime) |
| **LangGraph Platform license** | Required for production images | Not required |
| **Cold start** | Slower (two services to initialize) | Faster |
Both modes are functionally equivalent — the same agents, tools, and skills work in either mode.
#### Docker Production Deployment
`deploy.sh` supports building and starting separately. Images are mode-agnostic — runtime mode is selected at start time:
```bash
# One-step (build + start)
deploy.sh # standard mode (default)
deploy.sh --gateway # gateway mode
# Two-step (build once, start with any mode)
deploy.sh build # build all images
deploy.sh start # start in standard mode
deploy.sh start --gateway # start in gateway mode
# Stop
deploy.sh down
```
### Advanced
#### Sandbox Mode
@@ -302,6 +361,7 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when
| Telegram | Bot API (long-polling) | Easy |
| Slack | Socket Mode | Moderate |
| Feishu / Lark | WebSocket | Moderate |
| WeCom | WebSocket | Moderate |
**Configuration in `config.yaml`:**
@@ -329,6 +389,11 @@ channels:
# domain: https://open.feishu.cn # China (default)
# domain: https://open.larksuite.com # International
wecom:
enabled: true
bot_id: $WECOM_BOT_ID
bot_secret: $WECOM_BOT_SECRET
slack:
enabled: true
bot_token: $SLACK_BOT_TOKEN # xoxb-...
@@ -372,6 +437,10 @@ SLACK_APP_TOKEN=xapp-...
# Feishu / Lark
FEISHU_APP_ID=cli_xxxx
FEISHU_APP_SECRET=your_app_secret
# WeCom
WECOM_BOT_ID=your_bot_id
WECOM_BOT_SECRET=your_bot_secret
```
**Telegram Setup**
@@ -394,6 +463,14 @@ FEISHU_APP_SECRET=your_app_secret
3. Under **Events**, subscribe to `im.message.receive_v1` and select **Long Connection** mode.
4. Copy the App ID and App Secret. Set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env` and enable the channel in `config.yaml`.
**WeCom Setup**
1. Create a bot on the WeCom AI Bot platform and obtain the `bot_id` and `bot_secret`.
2. Enable `channels.wecom` in `config.yaml` and fill in `bot_id` / `bot_secret`.
3. Set `WECOM_BOT_ID` and `WECOM_BOT_SECRET` in `.env`.
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://langgraph:2024` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
**Commands**

@@ -232,6 +232,7 @@ DeerFlow supports receiving tasks from messaging apps. Once configured, the corresponding
| Telegram | Bot API (long-polling) | Easy |
| Slack | Socket Mode | Moderate |
| Feishu / Lark | WebSocket | Moderate |
| WeCom AI Bot | WebSocket | Moderate |
**Example configuration in `config.yaml`:**
@@ -259,6 +260,11 @@ channels:
# domain: https://open.feishu.cn # China (default)
# domain: https://open.larksuite.com # International
wecom:
enabled: true
bot_id: $WECOM_BOT_ID
bot_secret: $WECOM_BOT_SECRET
slack:
enabled: true
bot_token: $SLACK_BOT_TOKEN # xoxb-...
@@ -302,6 +308,10 @@ SLACK_APP_TOKEN=xapp-...
# Feishu / Lark
FEISHU_APP_ID=cli_xxxx
FEISHU_APP_SECRET=your_app_secret
# WeCom AI Bot
WECOM_BOT_ID=your_bot_id
WECOM_BOT_SECRET=your_bot_secret
```
**Telegram Setup**
@@ -324,6 +334,14 @@ FEISHU_APP_SECRET=your_app_secret
3. Under **Event Subscriptions**, subscribe to `im.message.receive_v1` and select **Long Connection** mode.
4. Copy the App ID and App Secret, set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env`, and enable the channel in `config.yaml`.
**WeCom AI Bot Setup**
1. Create a bot on the WeCom AI Bot platform and obtain the `bot_id` and `bot_secret`.
2. Enable `channels.wecom` in `config.yaml` and fill in `bot_id` / `bot_secret`.
3. Set `WECOM_BOT_ID` and `WECOM_BOT_SECRET` in `.env`.
4. Make sure backend dependencies include `wecom-aibot-python-sdk`; the channel receives messages over a WebSocket long connection, so no public callback URL is required.
5. Inbound text, image, and file messages are currently supported; final images/files generated by the agent are also sent back to the WeCom conversation.
**Commands**
Once a channel is connected, you can interact with DeerFlow directly in the chat window:

@@ -13,6 +13,10 @@ DeerFlow is a LangGraph-based AI super agent system with a full-stack architectu
- **Nginx** (port 2026): Unified reverse proxy entry point
- **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode
**Runtime Modes**:
- **Standard mode** (`make dev`): LangGraph Server handles agent execution as a separate process. 4 processes total.
- **Gateway mode** (`make dev-pro`, experimental): Agent runtime embedded in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Service manages its own concurrency via async tasks. 3 processes total, no LangGraph Server.
**Project Structure**:
```
deer-flow/
@@ -80,6 +84,8 @@ When making code changes, you MUST update the relevant documentation:
make check # Check system requirements
make install # Install all dependencies (frontend + backend)
make dev # Start all services (LangGraph + Gateway + Frontend + Nginx), with config.yaml preflight
make dev-pro # Gateway mode (experimental): skip LangGraph, agent runtime embedded in Gateway
make start-pro # Production + Gateway mode (experimental)
make stop # Stop all services
```
@@ -436,8 +442,25 @@ make dev
This starts all services and makes the application available at `http://localhost:2026`.
**All startup modes:**
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|---|---|---|---|---|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
| Action | Local | Docker Dev | Docker Prod |
|---|---|---|---|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
Gateway mode embeds the agent runtime in the Gateway process; no separate LangGraph server runs.
**Nginx routing**:
- `/api/langgraph/*` → LangGraph Server (2024)
- Standard mode: `/api/langgraph/*` → LangGraph Server (2024)
- Gateway mode: `/api/langgraph/*` → Gateway embedded runtime (8001) (via envsubst)
- `/api/*` (other) → Gateway API (8001)
- `/` (non-API) → Frontend (3000)

@@ -1,34 +1,86 @@
# Backend Development Dockerfile
# Backend Dockerfile — multi-stage build
# Stage 1 (builder): compiles native Python extensions with build-essential
# Stage 2 (dev): retains toolchain for dev containers (uv sync at startup)
# Stage 3 (runtime): clean image without compiler toolchain for production
# UV source image (override for restricted networks that cannot reach ghcr.io)
ARG UV_IMAGE=ghcr.io/astral-sh/uv:0.7.20
FROM ${UV_IMAGE} AS uv-source
FROM python:3.12-slim-bookworm
# ── Stage 1: Builder ──────────────────────────────────────────────────────────
FROM python:3.12-slim-bookworm AS builder
ARG NODE_MAJOR=22
ARG NODE_VERSION=22.16.0
ARG APT_MIRROR
ARG UV_INDEX_URL
ARG NODE_DIST_URL
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.byted.org)
RUN if [ -n "${APT_MIRROR}" ]; then \
sed -i "s|deb.debian.org|${APT_MIRROR}|g" /etc/apt/sources.list.d/debian.sources 2>/dev/null || true; \
sed -i "s|deb.debian.org|${APT_MIRROR}|g" /etc/apt/sources.list 2>/dev/null || true; \
fi
# Install system dependencies + Node.js (provides npx for MCP servers)
# Install build tools + Node.js (build-essential needed for native Python extensions)
# NODE_DIST_URL: base URL for Node.js binary tarballs in restricted networks.
# npmmirror: https://registry.npmmirror.com/-/binary/node
# official: https://nodejs.org/dist (default, via nodesource apt)
RUN apt-get update && apt-get install -y \
curl \
build-essential \
gnupg \
ca-certificates \
&& mkdir -p /etc/apt/keyrings \
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" > /etc/apt/sources.list.d/nodesource.list \
&& apt-get update \
&& apt-get install -y nodejs \
xz-utils \
&& if [ -n "${NODE_DIST_URL}" ]; then \
curl -fsSL "${NODE_DIST_URL}/v${NODE_VERSION}/node-v${NODE_VERSION}-linux-x64.tar.xz" \
| tar -xJ --strip-components=1 -C /usr/local \
&& ln -sf /usr/local/bin/node /usr/bin/node \
&& ln -sf /usr/local/lib/node_modules /usr/lib/node_modules; \
else \
mkdir -p /etc/apt/keyrings \
&& curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
&& echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" > /etc/apt/sources.list.d/nodesource.list \
&& apt-get update \
&& apt-get install -y nodejs; \
fi \
&& rm -rf /var/lib/apt/lists/*
# Install uv (source image overridable via UV_IMAGE build arg)
COPY --from=uv-source /uv /uvx /usr/local/bin/
# Set working directory
WORKDIR /app
# Copy backend source code
COPY backend ./backend
# Install dependencies with cache mount
RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build
# source distributions in development containers.
FROM builder AS dev
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
EXPOSE 8001 2024
CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
# ── Stage 3: Runtime ──────────────────────────────────────────────────────────
# Clean image without build-essential — reduces size (~200 MB) and attack surface.
FROM python:3.12-slim-bookworm
# Copy Node.js runtime from builder (provides npx for MCP servers)
COPY --from=builder /usr/bin/node /usr/bin/node
COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
&& ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx
# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
@@ -38,12 +90,8 @@ COPY --from=uv-source /uv /uvx /usr/local/bin/
# Set working directory
WORKDIR /app
# Copy frontend source code
COPY backend ./backend
# Install dependencies with cache mount
RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
# Copy backend with pre-built virtualenv from builder
COPY --from=builder /app/backend ./backend
# Expose ports (gateway: 8001, langgraph: 2024)
EXPOSE 8001 2024

@@ -2,7 +2,7 @@ install:
uv sync
dev:
uv run langgraph dev --no-browser --allow-blocking --no-reload --n-jobs-per-worker 10
uv run langgraph dev --no-browser --no-reload --n-jobs-per-worker 10
gateway:
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001

@@ -206,7 +206,9 @@ class FeishuChannel(Channel):
await asyncio.sleep(delay)
logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc)
raise last_exc # type: ignore[misc]
if last_exc is None:
raise RuntimeError("Feishu send failed without an exception from any attempt")
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._api_client:

@@ -7,9 +7,10 @@ import logging
import mimetypes
import re
import time
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from typing import Any
import httpx
from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
@@ -36,8 +37,49 @@ CHANNEL_CAPABILITIES = {
"feishu": {"supports_streaming": True},
"slack": {"supports_streaming": False},
"telegram": {"supports_streaming": False},
"wecom": {"supports_streaming": True},
}
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
def register_inbound_file_reader(channel_name: str, reader: InboundFileReader) -> None:
INBOUND_FILE_READERS[channel_name] = reader
async def _read_http_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
url = file_info.get("url")
if not isinstance(url, str) or not url:
return None
resp = await client.get(url)
resp.raise_for_status()
return resp.content
async def _read_wecom_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
data = await _read_http_inbound_file(file_info, client)
if data is None:
return None
aeskey = file_info.get("aeskey") if isinstance(file_info.get("aeskey"), str) else None
if not aeskey:
return data
try:
from aibot.crypto_utils import decrypt_file
except Exception:
logger.exception("[Manager] failed to import WeCom decrypt_file")
return None
return decrypt_file(data, aeskey)
register_inbound_file_reader("wecom", _read_wecom_inbound_file)
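
Channels with bespoke download semantics (like WeCom's AES-encrypted media
above) plug into this registry; anything unregistered falls back to
`_read_http_inbound_file`. A hypothetical custom reader, using the
`InboundFileReader` signature defined in this hunk:

```python
import httpx
from typing import Any

async def _read_mychannel_inbound_file(
    file_info: dict[str, Any], client: httpx.AsyncClient
) -> bytes | None:
    # "mychannel" and the auth header are illustrative, not part of this commit.
    url = file_info.get("url")
    if not isinstance(url, str) or not url:
        return None
    resp = await client.get(url, headers={"Authorization": "Bearer <token>"})
    resp.raise_for_status()
    return resp.content

register_inbound_file_reader("mychannel", _read_mychannel_inbound_file)
```
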
class InvalidChannelSessionConfigError(ValueError):
"""Raised when IM channel session overrides contain invalid agent config."""
@@ -342,6 +384,105 @@ def _prepare_artifact_delivery(
return response_text, attachments
async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dict[str, Any]]:
if not msg.files:
return []
from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
uploads_dir = ensure_uploads_dir(thread_id)
seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
created: list[dict[str, Any]] = []
file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
async with httpx.AsyncClient(timeout=httpx.Timeout(20.0)) as client:
for idx, f in enumerate(msg.files):
if not isinstance(f, dict):
continue
ftype = f.get("type") if isinstance(f.get("type"), str) else "file"
filename = f.get("filename") if isinstance(f.get("filename"), str) else ""
try:
data = await file_reader(f, client)
except Exception:
logger.exception(
"[Manager] failed to read inbound file: channel=%s, file=%s",
msg.channel_name,
f.get("url") or filename or idx,
)
continue
if data is None:
logger.warning(
"[Manager] inbound file reader returned no data: channel=%s, file=%s",
msg.channel_name,
f.get("url") or filename or idx,
)
continue
if not filename:
ext = ".bin"
if ftype == "image":
ext = ".png"
filename = f"{msg.thread_ts or 'msg'}_{idx}{ext}"
try:
safe_name = claim_unique_filename(normalize_filename(filename), seen_names)
except ValueError:
logger.warning(
"[Manager] skipping inbound file with unsafe filename: channel=%s, file=%r",
msg.channel_name,
filename,
)
continue
dest = uploads_dir / safe_name
try:
dest.write_bytes(data)
except Exception:
logger.exception("[Manager] failed to write inbound file: %s", dest)
continue
created.append(
{
"filename": safe_name,
"size": len(data),
"path": f"/mnt/user-data/uploads/{safe_name}",
"is_image": ftype == "image",
}
)
return created
def _format_uploaded_files_block(files: list[dict[str, Any]]) -> str:
lines = [
"<uploaded_files>",
"The following files were uploaded in this message:",
"",
]
if not files:
lines.append("(empty)")
else:
for f in files:
filename = f.get("filename", "")
size = int(f.get("size") or 0)
size_kb = size / 1024 if size else 0
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
path = f.get("path", "")
is_image = bool(f.get("is_image"))
file_kind = "image" if is_image else "file"
lines.append(f"- {filename} ({size_str})")
lines.append(f" Type: {file_kind}")
lines.append(f" Path: {path}")
lines.append("")
lines.append("Use `read_file` for text-based files and documents.")
lines.append("Use `view_image` for image files (jpg, jpeg, png, webp) so the model can inspect the image content.")
lines.append("</uploaded_files>")
return "\n".join(lines)
class ChannelManager:
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
@@ -536,6 +677,11 @@ class ChannelManager:
assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id)
if extra_context:
run_context.update(extra_context)
uploaded = await _ingest_inbound_files(thread_id, msg)
if uploaded:
msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
if self._channel_supports_streaming(msg.channel_name):
await self._handle_streaming_chat(
client,

@@ -17,6 +17,7 @@ _CHANNEL_REGISTRY: dict[str, str] = {
"feishu": "app.channels.feishu:FeishuChannel",
"slack": "app.channels.slack:SlackChannel",
"telegram": "app.channels.telegram:TelegramChannel",
"wecom": "app.channels.wecom:WeComChannel",
}
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"

@@ -30,7 +30,7 @@ class SlackChannel(Channel):
self._socket_client = None
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users: set[str] = set(config.get("allowed_users", []))
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
async def start(self) -> None:
if self._running:
@@ -126,7 +126,9 @@ class SlackChannel(Channel):
)
except Exception:
pass
raise last_exc # type: ignore[misc]
if last_exc is None:
raise RuntimeError("Slack send failed without an exception from any attempt")
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._web_client:

@@ -125,7 +125,9 @@ class TelegramChannel(Channel):
await asyncio.sleep(delay)
logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc)
raise last_exc # type: ignore[misc]
if last_exc is None:
raise RuntimeError("Telegram send failed without an exception from any attempt")
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._application:

@@ -0,0 +1,394 @@
from __future__ import annotations
import asyncio
import base64
import hashlib
import logging
from collections.abc import Awaitable, Callable
from typing import Any, cast
from app.channels.base import Channel
from app.channels.message_bus import (
InboundMessageType,
MessageBus,
OutboundMessage,
ResolvedAttachment,
)
logger = logging.getLogger(__name__)
class WeComChannel(Channel):
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="wecom", bus=bus, config=config)
self._bot_id: str | None = None
self._bot_secret: str | None = None
self._ws_client = None
self._ws_task: asyncio.Task | None = None
self._ws_frames: dict[str, dict[str, Any]] = {}
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
def _clear_ws_context(self, thread_ts: str | None) -> None:
if not thread_ts:
return
self._ws_frames.pop(thread_ts, None)
self._ws_stream_ids.pop(thread_ts, None)
async def _send_ws_upload_command(self, req_id: str, body: dict[str, Any], cmd: str) -> dict[str, Any]:
if not self._ws_client:
raise RuntimeError("WeCom WebSocket client is not available")
ws_manager = getattr(self._ws_client, "_ws_manager", None)
send_reply = getattr(ws_manager, "send_reply", None)
if not callable(send_reply):
raise RuntimeError("Installed wecom-aibot-python-sdk does not expose the WebSocket media upload API expected by DeerFlow. Use wecom-aibot-python-sdk==0.1.6 or update the adapter.")
send_reply_async = cast(Callable[[str, dict[str, Any], str], Awaitable[dict[str, Any]]], send_reply)
return await send_reply_async(req_id, body, cmd)
async def start(self) -> None:
if self._running:
return
bot_id = self.config.get("bot_id")
bot_secret = self.config.get("bot_secret")
working_message = self.config.get("working_message")
self._bot_id = bot_id if isinstance(bot_id, str) and bot_id else None
self._bot_secret = bot_secret if isinstance(bot_secret, str) and bot_secret else None
self._working_message = working_message if isinstance(working_message, str) and working_message else "Working on it..."
if not self._bot_id or not self._bot_secret:
logger.error("WeCom channel requires bot_id and bot_secret")
return
try:
from aibot import WSClient, WSClientOptions
except ImportError:
logger.error("wecom-aibot-python-sdk is not installed. Install it with: uv add wecom-aibot-python-sdk")
return
else:
self._ws_client = WSClient(WSClientOptions(bot_id=self._bot_id, secret=self._bot_secret, logger=logger))
self._ws_client.on("message.text", self._on_ws_text)
self._ws_client.on("message.mixed", self._on_ws_mixed)
self._ws_client.on("message.image", self._on_ws_image)
self._ws_client.on("message.file", self._on_ws_file)
self._ws_task = asyncio.create_task(self._ws_client.connect())
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
logger.info("WeCom channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
if self._ws_task:
try:
self._ws_task.cancel()
except Exception:
pass
self._ws_task = None
if self._ws_client:
try:
self._ws_client.disconnect()
except Exception:
pass
self._ws_client = None
self._ws_frames.clear()
self._ws_stream_ids.clear()
logger.info("WeCom channel stopped")
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if self._ws_client:
await self._send_ws(msg, _max_retries=_max_retries)
return
logger.warning("[WeCom] send called but WebSocket client is not available")
async def _on_outbound(self, msg: OutboundMessage) -> None:
if msg.channel_name != self.name:
return
try:
await self.send(msg)
except Exception:
logger.exception("Failed to send outbound message on channel %s", self.name)
if msg.is_final:
self._clear_ws_context(msg.thread_ts)
return
for attachment in msg.attachments:
try:
success = await self.send_file(msg, attachment)
if not success:
logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename)
except Exception:
logger.exception("[%s] failed to upload file %s", self.name, attachment.filename)
if msg.is_final:
self._clear_ws_context(msg.thread_ts)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not msg.is_final:
return True
if not self._ws_client:
return False
if not msg.thread_ts:
return False
frame = self._ws_frames.get(msg.thread_ts)
if not frame:
return False
media_type = "image" if attachment.is_image else "file"
size_limit = 2 * 1024 * 1024 if attachment.is_image else 20 * 1024 * 1024
if attachment.size > size_limit:
logger.warning(
"[WeCom] %s too large (%d bytes), skipping: %s",
media_type,
attachment.size,
attachment.filename,
)
return False
try:
media_id = await self._upload_media_ws(
media_type=media_type,
filename=attachment.filename,
path=str(attachment.actual_path),
size=attachment.size,
)
if not media_id:
return False
body = {media_type: {"media_id": media_id}, "msgtype": media_type}
await self._ws_client.reply(frame, body)
logger.debug("[WeCom] %s sent via ws: %s", media_type, attachment.filename)
return True
except Exception:
logger.exception("[WeCom] failed to upload/send file via ws: %s", attachment.filename)
return False
async def _on_ws_text(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
text = ((body.get("text") or {}).get("content") or "").strip()
quote = body.get("quote", {}).get("text", {}).get("content", "").strip()
if not text and not quote:
return
await self._publish_ws_inbound(frame, text + (f"\nQuote message: {quote}" if quote else ""))
async def _on_ws_mixed(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
mixed = body.get("mixed") or {}
items = mixed.get("msg_item") or []
parts: list[str] = []
files: list[dict[str, Any]] = []
for item in items:
item_type = (item or {}).get("msgtype")
if item_type == "text":
content = (((item or {}).get("text") or {}).get("content") or "").strip()
if content:
parts.append(content)
elif item_type in ("image", "file"):
payload = (item or {}).get(item_type) or {}
url = payload.get("url")
aeskey = payload.get("aeskey")
if isinstance(url, str) and url:
files.append(
{
"type": item_type,
"url": url,
"aeskey": (aeskey if isinstance(aeskey, str) and aeskey else None),
}
)
text = "\n\n".join(parts).strip()
if not text and not files:
return
if not text:
text = "receive image/file"
await self._publish_ws_inbound(frame, text, files=files)
async def _on_ws_image(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
image = body.get("image") or {}
url = image.get("url")
aeskey = image.get("aeskey")
if not isinstance(url, str) or not url:
return
await self._publish_ws_inbound(
frame,
"receive image ",
files=[
{
"type": "image",
"url": url,
"aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
}
],
)
async def _on_ws_file(self, frame: dict[str, Any]) -> None:
body = frame.get("body", {}) or {}
file_obj = body.get("file") or {}
url = file_obj.get("url")
aeskey = file_obj.get("aeskey")
if not isinstance(url, str) or not url:
return
await self._publish_ws_inbound(
frame,
"receive file",
files=[
{
"type": "file",
"url": url,
"aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
}
],
)
async def _publish_ws_inbound(
self,
frame: dict[str, Any],
text: str,
*,
files: list[dict[str, Any]] | None = None,
) -> None:
if not self._ws_client:
return
try:
from aibot import generate_req_id
except Exception:
return
body = frame.get("body", {}) or {}
msg_id = body.get("msgid")
if not msg_id:
return
user_id = (body.get("from") or {}).get("userid")
inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=user_id, # keep user's conversation in memory
user_id=user_id,
text=text,
msg_type=inbound_type,
thread_ts=msg_id,
files=files or [],
metadata={"aibotid": body.get("aibotid"), "chattype": body.get("chattype")},
)
inbound.topic_id = user_id # keep the same thread
stream_id = generate_req_id("stream")
self._ws_frames[msg_id] = frame
self._ws_stream_ids[msg_id] = stream_id
try:
await self._ws_client.reply_stream(frame, stream_id, self._working_message, False)
except Exception:
pass
await self.bus.publish_inbound(inbound)
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
if not self._ws_client:
return
try:
from aibot import generate_req_id
except Exception:
generate_req_id = None
if msg.thread_ts and msg.thread_ts in self._ws_frames:
frame = self._ws_frames[msg.thread_ts]
stream_id = self._ws_stream_ids.get(msg.thread_ts)
if not stream_id and generate_req_id:
stream_id = generate_req_id("stream")
self._ws_stream_ids[msg.thread_ts] = stream_id
if not stream_id:
return
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
await self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final))
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
await asyncio.sleep(2**attempt)
if last_exc:
raise last_exc
body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
last_exc = None
for attempt in range(_max_retries):
try:
await self._ws_client.send_message(msg.chat_id, body)
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
await asyncio.sleep(2**attempt)
if last_exc:
raise last_exc
async def _upload_media_ws(
self,
*,
media_type: str,
filename: str,
path: str,
size: int,
) -> str | None:
if not self._ws_client:
return None
try:
from aibot import generate_req_id
except Exception:
return None
chunk_size = 512 * 1024
total_chunks = (size + chunk_size - 1) // chunk_size
if total_chunks < 1 or total_chunks > 100:
logger.warning("[WeCom] invalid total_chunks=%d for %s", total_chunks, filename)
return None
md5_hasher = hashlib.md5()
with open(path, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
md5_hasher.update(chunk)
md5 = md5_hasher.hexdigest()
init_req_id = generate_req_id("aibot_upload_media_init")
init_body = {
"type": media_type,
"filename": filename,
"total_size": int(size),
"total_chunks": int(total_chunks),
"md5": md5,
}
init_ack = await self._send_ws_upload_command(init_req_id, init_body, "aibot_upload_media_init")
upload_id = (init_ack.get("body") or {}).get("upload_id")
if not upload_id:
logger.warning("[WeCom] upload init returned no upload_id: %s", init_ack)
return None
with open(path, "rb") as f:
for idx in range(total_chunks):
data = f.read(chunk_size)
if not data:
break
chunk_req_id = generate_req_id("aibot_upload_media_chunk")
chunk_body = {
"upload_id": upload_id,
"chunk_index": int(idx),
"base64_data": base64.b64encode(data).decode("utf-8"),
}
await self._send_ws_upload_command(chunk_req_id, chunk_body, "aibot_upload_media_chunk")
finish_req_id = generate_req_id("aibot_upload_media_finish")
finish_ack = await self._send_ws_upload_command(finish_req_id, {"upload_id": upload_id}, "aibot_upload_media_finish")
media_id = (finish_ack.get("body") or {}).get("media_id")
if not media_id:
logger.warning("[WeCom] upload finish returned no media_id: %s", finish_ack)
return None
return media_id
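
The upload flow above is three WebSocket commands (init → chunk → finish)
over 512 KB chunks, capped at 100 chunks per upload. For example:

```python
chunk_size = 512 * 1024                                # 524,288 bytes per chunk
size = 3_200_000                                       # a ~3.2 MB file
total_chunks = (size + chunk_size - 1) // chunk_size   # ceil division -> 7
assert 1 <= total_chunks <= 100                        # cap: 100 × 512 KiB = 50 MiB
```

In practice `send_file` rejects payloads earlier (2 MB for images, 20 MB for
files), so the 100-chunk ceiling is never the binding limit.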

@@ -1,15 +1,21 @@
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from datetime import UTC
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
artifacts,
assistants_compat,
auth,
channels,
mcp,
memory,
@@ -33,6 +39,88 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
async def _ensure_admin_user(app: FastAPI) -> None:
"""Auto-create the admin user on first boot if no users exist.
Logs the generated password so the operator can log in.
On subsequent boots, resets the password if the admin still needs setup.
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve races.
Only the worker that successfully creates/updates the admin prints the
password; losers silently skip.
"""
import secrets
from app.gateway.deps import get_local_provider
provider = get_local_provider()
user_count = await provider.count_users()
if user_count == 0:
password = secrets.token_urlsafe(16)
try:
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
except ValueError:
return # Another worker already created the admin.
# Migrate orphaned threads (no user_id) to this admin
store = getattr(app.state, "store", None)
if store is not None:
await _migrate_orphaned_threads(store, str(admin.id))
logger.info("=" * 60)
logger.info(" Admin account created on first boot")
logger.info(" Email: %s", admin.email)
logger.info(" Password: %s", password)
logger.info(" Change it after login: Settings -> Account")
logger.info("=" * 60)
return
# Admin exists but setup never completed — reset password so operator
# can always find it in the console without needing the CLI.
# Multi-worker guard: if the admin was created less than 30s ago, another
# worker just created it and already logged the password, so skip the reset.
admin = await provider.get_user_by_email("admin@deerflow.dev")
if admin and admin.needs_setup:
import time
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
if age < 30:
return # Just created by another worker in this startup; its password is still valid.
from app.gateway.auth.password import hash_password_async
password = secrets.token_urlsafe(16)
admin.password_hash = await hash_password_async(password)
admin.token_version += 1
await provider.update_user(admin)
logger.info("=" * 60)
logger.info(" Admin account setup incomplete — password reset")
logger.info(" Email: %s", admin.email)
logger.info(" Password: %s", password)
logger.info(" Change it after login: Settings -> Account")
logger.info("=" * 60)
async def _migrate_orphaned_threads(store, admin_user_id: str) -> None:
"""Migrate threads with no user_id to the given admin."""
try:
migrated = 0
results = await store.asearch(("threads",), limit=1000)
for item in results:
metadata = item.value.get("metadata", {})
if not metadata.get("user_id"):
metadata["user_id"] = admin_user_id
item.value["metadata"] = metadata
await store.aput(("threads",), item.key, item.value)
migrated += 1
if migrated:
logger.info("Migrated %d orphaned thread(s) to admin", migrated)
except Exception:
logger.exception("Thread migration failed (non-fatal)")
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan handler."""
@@ -52,6 +140,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Ensure admin user exists (auto-create on first boot)
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
await _ensure_admin_user(app)
# Start IM channel service if any channels are configured
try:
from app.channels.service import start_channel_service
@@ -163,7 +255,30 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
],
)
# CORS is handled by nginx - no need for FastAPI middleware
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
app.add_middleware(AuthMiddleware)
# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
# Validate: wildcard origin with credentials is a security misconfiguration
for origin in cors_origins:
if origin == "*":
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
cors_origins = [o for o in cors_origins if o != "*"]
break
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
# Models API is mounted at /api/models
@@ -199,6 +314,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Auth API is mounted at /api/v1/auth
app.include_router(auth.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)
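
The `CSRFMiddleware` registered above lives in a file not shown in this hunk.
A minimal double-submit-cookie check on Starlette's middleware API might look
like this (cookie/header names are assumptions, not the project's actual
implementation):

```python
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse

class DoubleSubmitCSRF(BaseHTTPMiddleware):
    """Reject state-changing requests whose CSRF header mismatches the cookie."""

    SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}

    async def dispatch(self, request: Request, call_next):
        if request.method not in self.SAFE_METHODS:
            cookie = request.cookies.get("csrf_token")
            header = request.headers.get("x-csrf-token")
            if not cookie or cookie != header:
                return JSONResponse(
                    {"detail": "CSRF token missing or invalid"}, status_code=403
                )
        return await call_next(request)
```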

@@ -0,0 +1,42 @@
"""Authentication module for DeerFlow.
This module provides:
- JWT-based authentication
- Provider Factory pattern for extensible auth methods
- UserRepository interface for storage backends (SQLite)
"""
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.models import User, UserResponse
from app.gateway.auth.password import hash_password, verify_password
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
__all__ = [
# Config
"AuthConfig",
"get_auth_config",
"set_auth_config",
# Errors
"AuthErrorCode",
"AuthErrorResponse",
"TokenError",
# JWT
"TokenPayload",
"create_access_token",
"decode_token",
# Password
"hash_password",
"verify_password",
# Models
"User",
"UserResponse",
# Providers
"AuthProvider",
"LocalAuthProvider",
# Repository
"UserRepository",
]
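
Taken together, the exports support a round trip like this (a sketch; the
real wiring lives in the auth routers and `app.py`):

```python
from app.gateway.auth import (
    TokenPayload,
    create_access_token,
    decode_token,
    hash_password,
    verify_password,
)

stored = hash_password("s3cret")            # bcrypt hash, safe to persist
assert verify_password("s3cret", stored)    # login-time check

token = create_access_token(user_id="some-user-uuid", token_version=0)
payload = decode_token(token)               # TokenPayload, or TokenError on failure
assert isinstance(payload, TokenPayload) and payload.ver == 0
```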

@@ -0,0 +1,55 @@
"""Authentication configuration for DeerFlow."""
import logging
import os
import secrets
from dotenv import load_dotenv
from pydantic import BaseModel, Field
load_dotenv()
logger = logging.getLogger(__name__)
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup."""
jwt_secret: str = Field(
...,
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
)
token_expiry_days: int = Field(default=7, ge=1, le=30)
users_db_path: str | None = Field(
default=None,
description="Path to users SQLite DB. Defaults to .deer-flow/users.db",
)
oauth_github_client_id: str | None = Field(default=None)
oauth_github_client_secret: str | None = Field(default=None)
_auth_config: AuthConfig | None = None
def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
if _auth_config is None:
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)
os.environ["AUTH_JWT_SECRET"] = jwt_secret
logger.warning(
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
"Sessions will be invalidated on restart. "
"For production, add AUTH_JWT_SECRET to your .env file: "
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
_auth_config = AuthConfig(jwt_secret=jwt_secret)
return _auth_config
def set_auth_config(config: AuthConfig) -> None:
"""Set the global AuthConfig instance (for testing)."""
global _auth_config
_auth_config = config

@@ -0,0 +1,44 @@
"""Typed error definitions for auth module.
AuthErrorCode: exhaustive enum of all auth failure conditions.
TokenError: exhaustive enum of JWT decode failures.
AuthErrorResponse: structured error payload for HTTP responses.
"""
from enum import StrEnum
from pydantic import BaseModel
class AuthErrorCode(StrEnum):
"""Exhaustive list of auth error conditions."""
INVALID_CREDENTIALS = "invalid_credentials"
TOKEN_EXPIRED = "token_expired"
TOKEN_INVALID = "token_invalid"
USER_NOT_FOUND = "user_not_found"
EMAIL_ALREADY_EXISTS = "email_already_exists"
PROVIDER_NOT_FOUND = "provider_not_found"
NOT_AUTHENTICATED = "not_authenticated"
class TokenError(StrEnum):
"""Exhaustive list of JWT decode failure reasons."""
EXPIRED = "expired"
INVALID_SIGNATURE = "invalid_signature"
MALFORMED = "malformed"
class AuthErrorResponse(BaseModel):
"""Structured error response — replaces bare `detail` strings."""
code: AuthErrorCode
message: str
def token_error_to_code(err: TokenError) -> AuthErrorCode:
"""Map TokenError to AuthErrorCode — single source of truth."""
if err == TokenError.EXPIRED:
return AuthErrorCode.TOKEN_EXPIRED
return AuthErrorCode.TOKEN_INVALID

@@ -0,0 +1,55 @@
"""JWT token creation and verification."""
from datetime import UTC, datetime, timedelta
import jwt
from pydantic import BaseModel
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import TokenError
class TokenPayload(BaseModel):
"""JWT token payload."""
sub: str # user_id
exp: datetime
iat: datetime | None = None
ver: int = 0 # token_version — must match User.token_version
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
"""Create a JWT access token.
Args:
user_id: The user's UUID as string
expires_delta: Optional custom expiry, defaults to 7 days
token_version: User's current token_version for invalidation
Returns:
Encoded JWT string
"""
config = get_auth_config()
expiry = expires_delta or timedelta(days=config.token_expiry_days)
now = datetime.now(UTC)
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
def decode_token(token: str) -> TokenPayload | TokenError:
"""Decode and validate a JWT token.
Returns:
TokenPayload if valid, or a specific TokenError variant.
"""
config = get_auth_config()
try:
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
return TokenPayload(**payload)
except jwt.ExpiredSignatureError:
return TokenError.EXPIRED
except jwt.InvalidSignatureError:
return TokenError.INVALID_SIGNATURE
except jwt.PyJWTError:
return TokenError.MALFORMED
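
Because `decode_token` returns a union instead of raising, callers branch
explicitly and can map failures through `token_error_to_code` from
`errors.py` (a sketch; `classify` is illustrative):

```python
from app.gateway.auth.errors import AuthErrorCode, TokenError, token_error_to_code
from app.gateway.auth.jwt import TokenPayload, decode_token

def classify(token: str) -> AuthErrorCode | None:
    """None for a valid token, otherwise the AuthErrorCode to surface."""
    result = decode_token(token)
    if isinstance(result, TokenError):
        return token_error_to_code(result)  # EXPIRED -> TOKEN_EXPIRED, else TOKEN_INVALID
    return None
```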

@@ -0,0 +1,87 @@
"""Local email/password authentication provider."""
from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
class LocalAuthProvider(AuthProvider):
"""Email/password authentication provider using local database."""
def __init__(self, repository: UserRepository):
"""Initialize with a UserRepository.
Args:
repository: UserRepository implementation (SQLite)
"""
self._repo = repository
async def authenticate(self, credentials: dict) -> User | None:
"""Authenticate with email and password.
Args:
credentials: dict with 'email' and 'password' keys
Returns:
User if authentication succeeds, None otherwise
"""
email = credentials.get("email")
password = credentials.get("password")
if not email or not password:
return None
user = await self._repo.get_user_by_email(email)
if user is None:
return None
if user.password_hash is None:
# OAuth user without local password
return None
if not await verify_password_async(password, user.password_hash):
return None
return user
async def get_user(self, user_id: str) -> User | None:
"""Get user by ID."""
return await self._repo.get_user_by_id(user_id)
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
"""Create a new local user.
Args:
email: User email address
password: Plain text password (will be hashed)
system_role: Role to assign ("admin" or "user")
needs_setup: If True, user must complete setup on first login
Returns:
Created User instance
"""
password_hash = await hash_password_async(password) if password else None
user = User(
email=email,
password_hash=password_hash,
system_role=system_role,
needs_setup=needs_setup,
)
return await self._repo.create_user(user)
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID."""
return await self._repo.get_user_by_oauth(provider, oauth_id)
async def count_users(self) -> int:
"""Return total number of registered users."""
return await self._repo.count_users()
async def update_user(self, user: User) -> User:
"""Update an existing user."""
return await self._repo.update_user(user)
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email."""
return await self._repo.get_user_by_email(email)
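
A sketch of the login path through this provider (`repo` is any
`UserRepository`, e.g. the SQLite implementation later in this commit;
`login` itself is illustrative):

```python
from app.gateway.auth import create_access_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.base import UserRepository

async def login(repo: UserRepository, email: str, password: str) -> str | None:
    provider = LocalAuthProvider(repo)
    user = await provider.authenticate({"email": email, "password": password})
    if user is None:
        return None  # unknown email, OAuth-only user, or wrong password
    return create_access_token(user_id=str(user.id), token_version=user.token_version)
```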

@@ -0,0 +1,41 @@
"""User Pydantic models for authentication."""
from datetime import UTC, datetime
from typing import Literal
from uuid import UUID, uuid4
from pydantic import BaseModel, ConfigDict, EmailStr, Field
def _utc_now() -> datetime:
"""Return current UTC time (timezone-aware)."""
return datetime.now(UTC)
class User(BaseModel):
"""Internal user representation."""
model_config = ConfigDict(from_attributes=True)
id: UUID = Field(default_factory=uuid4, description="Primary key")
email: EmailStr = Field(..., description="Unique email address")
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
system_role: Literal["admin", "user"] = Field(default="user")
created_at: datetime = Field(default_factory=_utc_now)
# OAuth linkage (optional)
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
# Auth lifecycle
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
class UserResponse(BaseModel):
"""Response model for user info endpoint."""
id: str
email: str
system_role: Literal["admin", "user"]
needs_setup: bool = False

@@ -0,0 +1,33 @@
"""Password hashing utilities using bcrypt directly."""
import asyncio
import bcrypt
def hash_password(password: str) -> str:
"""Hash a password using bcrypt."""
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash."""
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
async def hash_password_async(password: str) -> str:
"""Hash a password using bcrypt (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password hashing.
"""
return await asyncio.to_thread(hash_password, password)
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password verification.
"""
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
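# Illustrative sketch (the password literal is a placeholder; hashes differ
# on every call because bcrypt generates a random salt):
if __name__ == "__main__":
    stored = hash_password("hunter2-example")
    assert verify_password("hunter2-example", stored)
    assert not verify_password("wrong-guess", stored)
    assert asyncio.run(verify_password_async("hunter2-example", stored))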


@ -0,0 +1,24 @@
"""Auth provider abstraction."""
from abc import ABC, abstractmethod
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def authenticate(self, credentials: dict) -> "User | None":
"""Authenticate user with given credentials.
Returns User if authentication succeeds, None otherwise.
"""
...
@abstractmethod
async def get_user(self, user_id: str) -> "User | None":
"""Retrieve user by ID."""
...
# Imported after the class definition to avoid a circular import
from app.gateway.auth.models import User # noqa: E402


@ -0,0 +1,82 @@
"""User repository interface for abstracting database operations."""
from abc import ABC, abstractmethod
from app.gateway.auth.models import User
class UserRepository(ABC):
"""Abstract interface for user data storage.
Implement this interface to support different storage backends (e.g., SQLite).
"""
@abstractmethod
async def create_user(self, user: User) -> User:
"""Create a new user.
Args:
user: User object to create
Returns:
Created User with ID assigned
Raises:
ValueError: If email already exists
"""
...
@abstractmethod
async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID.
Args:
user_id: User UUID as string
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email.
Args:
email: User email address
Returns:
User if found, None otherwise
"""
...
@abstractmethod
async def update_user(self, user: User) -> User:
"""Update an existing user.
Args:
user: User object with updated fields
Returns:
Updated User
"""
...
@abstractmethod
async def count_users(self) -> int:
"""Return total number of registered users."""
...
@abstractmethod
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID.
Args:
provider: OAuth provider name (e.g. 'github', 'google')
oauth_id: User ID from the OAuth provider
Returns:
User if found, None otherwise
"""
...


@ -0,0 +1,196 @@
"""SQLite implementation of UserRepository."""
import asyncio
import sqlite3
from contextlib import contextmanager
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import UUID
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserRepository
_resolved_db_path: Path | None = None
_table_initialized: bool = False
def _get_users_db_path() -> Path:
"""Get the users database path (resolved and cached once)."""
global _resolved_db_path
if _resolved_db_path is not None:
return _resolved_db_path
config = get_auth_config()
if config.users_db_path:
_resolved_db_path = Path(config.users_db_path)
else:
_resolved_db_path = Path(".deer-flow/users.db")
_resolved_db_path.parent.mkdir(parents=True, exist_ok=True)
return _resolved_db_path
def _get_connection() -> sqlite3.Connection:
"""Get a SQLite connection for the users database."""
db_path = _get_users_db_path()
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
return conn
def _init_users_table(conn: sqlite3.Connection) -> None:
"""Initialize the users table if it doesn't exist."""
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(
"""
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT,
system_role TEXT NOT NULL DEFAULT 'user',
created_at REAL NOT NULL,
oauth_provider TEXT,
oauth_id TEXT,
needs_setup INTEGER NOT NULL DEFAULT 0,
token_version INTEGER NOT NULL DEFAULT 0
)
"""
)
# Add unique constraint for OAuth identity to prevent duplicate social logins
conn.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
ON users(oauth_provider, oauth_id)
WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
"""
)
conn.commit()
@contextmanager
def _get_users_conn():
"""Context manager for users database connection."""
global _table_initialized
conn = _get_connection()
try:
if not _table_initialized:
_init_users_table(conn)
_table_initialized = True
yield conn
finally:
conn.close()
class SQLiteUserRepository(UserRepository):
"""SQLite implementation of UserRepository."""
async def create_user(self, user: User) -> User:
"""Create a new user in SQLite."""
return await asyncio.to_thread(self._create_user_sync, user)
def _create_user_sync(self, user: User) -> User:
"""Synchronous user creation (runs in thread pool)."""
with _get_users_conn() as conn:
try:
conn.execute(
"""
INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
str(user.id),
user.email,
user.password_hash,
user.system_role,
datetime.now(UTC).timestamp(),
user.oauth_provider,
user.oauth_id,
int(user.needs_setup),
user.token_version,
),
)
conn.commit()
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed: users.email" in str(e):
raise ValueError(f"Email already registered: {user.email}") from e
raise
return user
async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID from SQLite."""
return await asyncio.to_thread(self._get_user_by_id_sync, user_id)
def _get_user_by_id_sync(self, user_id: str) -> User | None:
"""Synchronous get by ID (runs in thread pool)."""
with _get_users_conn() as conn:
cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,))
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email from SQLite."""
return await asyncio.to_thread(self._get_user_by_email_sync, email)
def _get_user_by_email_sync(self, email: str) -> User | None:
"""Synchronous get by email (runs in thread pool)."""
with _get_users_conn() as conn:
cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
async def update_user(self, user: User) -> User:
"""Update an existing user in SQLite."""
return await asyncio.to_thread(self._update_user_sync, user)
def _update_user_sync(self, user: User) -> User:
with _get_users_conn() as conn:
conn.execute(
"UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?",
(user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)),
)
conn.commit()
return user
async def count_users(self) -> int:
"""Return total number of registered users."""
return await asyncio.to_thread(self._count_users_sync)
def _count_users_sync(self) -> int:
with _get_users_conn() as conn:
cursor = conn.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID from SQLite."""
return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id)
def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None:
"""Synchronous get by OAuth (runs in thread pool)."""
with _get_users_conn() as conn:
cursor = conn.execute(
"SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
(provider, oauth_id),
)
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
@staticmethod
def _row_to_user(row: dict[str, Any]) -> User:
"""Convert a database row to a User model."""
return User(
id=UUID(row["id"]),
email=row["email"],
password_hash=row["password_hash"],
system_role=row["system_role"],
created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
oauth_provider=row.get("oauth_provider"),
oauth_id=row.get("oauth_id"),
needs_setup=bool(row["needs_setup"]),
token_version=int(row["token_version"]),
)
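# Illustrative usage sketch (assumes the default .deer-flow/users.db path;
# the email is a placeholder, and re-running would raise on the UNIQUE email):
async def _demo() -> None:
    repo = SQLiteUserRepository()
    created = await repo.create_user(User(email="demo@example.com"))
    fetched = await repo.get_user_by_email("demo@example.com")
    assert fetched is not None and fetched.id == created.id
    print("total users:", await repo.count_users())
# asyncio.run(_demo())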


@ -0,0 +1,66 @@
"""CLI tool to reset admin password.
Usage:
python -m app.gateway.auth.reset_admin
python -m app.gateway.auth.reset_admin --email admin@example.com
"""
import argparse
import secrets
import sys
from app.gateway.auth.password import hash_password
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
def main() -> None:
parser = argparse.ArgumentParser(description="Reset admin password")
parser.add_argument("--email", help="Admin email (default: first admin found)")
args = parser.parse_args()
repo = SQLiteUserRepository()
# Drive the async repo API with asyncio.run (CLI context, no running event loop)
import asyncio
user = asyncio.run(_find_admin(repo, args.email))
if user is None:
if args.email:
print(f"Error: user '{args.email}' not found.", file=sys.stderr)
else:
print("Error: no admin user found.", file=sys.stderr)
sys.exit(1)
new_password = secrets.token_urlsafe(16)
user.password_hash = hash_password(new_password)
user.token_version += 1
user.needs_setup = True
asyncio.run(repo.update_user(user))
print(f"Password reset for: {user.email}")
print(f"New password: {new_password}")
print("Next login will require setup (new email + password).")
async def _find_admin(repo: SQLiteUserRepository, email: str | None):
if email:
return await repo.get_user_by_email(email)
# Find first admin
import asyncio
from app.gateway.auth.repositories.sqlite import _get_users_conn
def _find_sync():
with _get_users_conn() as conn:
cursor = conn.execute("SELECT id FROM users WHERE system_role = 'admin' LIMIT 1")
row = cursor.fetchone()
return dict(row)["id"] if row else None
admin_id = await asyncio.to_thread(_find_sync)
if admin_id:
return await repo.get_user_by_id(admin_id)
return None
if __name__ == "__main__":
main()


@ -0,0 +1,71 @@
"""Global authentication middleware — fail-closed safety net.
Rejects unauthenticated requests to non-public paths with 401.
Fine-grained permission checks remain in authz.py decorators.
"""
from collections.abc import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode
# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
"/health",
"/docs",
"/redoc",
"/openapi.json",
)
# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
}
)
def _is_public(path: str) -> bool:
stripped = path.rstrip("/")
if stripped in _PUBLIC_EXACT_PATHS:
return True
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
class AuthMiddleware(BaseHTTPMiddleware):
"""Coarse-grained auth gate: reject requests without a valid session cookie.
This does NOT verify JWT signature or user existence; that is the job of
``get_current_user_from_request`` in deps.py (called by ``@require_auth``).
The middleware only checks *presence* of the cookie so that new endpoints
that forget ``@require_auth`` are not completely exposed.
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if _is_public(request.url.path):
return await call_next(request)
# Non-public path: require session cookie
if not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
"detail": {
"code": AuthErrorCode.NOT_AUTHENTICATED,
"message": "Authentication required",
}
},
)
return await call_next(request)
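# Minimal behavioural sketch (FastAPI TestClient; the cookie value is
# arbitrary because this gate checks presence only; deps.py validates the JWT):
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    demo = FastAPI()
    demo.add_middleware(AuthMiddleware)
    client = TestClient(demo)
    assert client.post("/api/v1/threads").status_code == 401  # no cookie -> rejected
    client.cookies.set("access_token", "anything")
    assert client.post("/api/v1/threads").status_code == 404  # gate passed; route undefined here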


@ -0,0 +1,261 @@
"""Authorization decorators and context for DeerFlow.
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
**Usage:**
1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
3. The decorator chain processes from bottom to top
**Example:**
@router.get("/{thread_id}")
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
# User is authenticated and has threads:read permission
...
**Permission Model:**
- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""
from __future__ import annotations
import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request
if TYPE_CHECKING:
from app.gateway.auth.models import User
P = ParamSpec("P")
T = TypeVar("T")
# Permission constants
class Permissions:
"""Permission constants for resource:action format."""
# Threads
THREADS_READ = "threads:read"
THREADS_WRITE = "threads:write"
THREADS_DELETE = "threads:delete"
# Runs
RUNS_CREATE = "runs:create"
RUNS_READ = "runs:read"
RUNS_CANCEL = "runs:cancel"
class AuthContext:
"""Authentication context for the current request.
Stored in request.state.auth after require_auth decoration.
Attributes:
user: The authenticated user, or None if anonymous
permissions: List of permission strings (e.g., "threads:read")
"""
__slots__ = ("user", "permissions")
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
self.user = user
self.permissions = permissions or []
@property
def is_authenticated(self) -> bool:
"""Check if user is authenticated."""
return self.user is not None
def has_permission(self, resource: str, action: str) -> bool:
"""Check if context has permission for resource:action.
Args:
resource: Resource name (e.g., "threads")
action: Action name (e.g., "read")
Returns:
True if user has permission
"""
permission = f"{resource}:{action}"
return permission in self.permissions
def require_user(self) -> User:
"""Get user or raise 401.
Raises:
HTTPException 401 if not authenticated
"""
if not self.user:
raise HTTPException(status_code=401, detail="Authentication required")
return self.user
def get_auth_context(request: Request) -> AuthContext | None:
"""Get AuthContext from request state."""
return getattr(request.state, "auth", None)
_ALL_PERMISSIONS: list[str] = [
Permissions.THREADS_READ,
Permissions.THREADS_WRITE,
Permissions.THREADS_DELETE,
Permissions.RUNS_CREATE,
Permissions.RUNS_READ,
Permissions.RUNS_CANCEL,
]
async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext.
Delegates to deps.get_optional_user_from_request() for the JWTUser pipeline.
Returns AuthContext with user=None for anonymous requests.
"""
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
if user is None:
return AuthContext(user=None, permissions=[])
# In future, permissions could be stored in user record
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and sets AuthContext.
Must be placed ABOVE other decorators (executes after them).
Usage:
@router.get("/{thread_id}")
@require_auth # Bottom decorator (executes first after permission check)
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
auth: AuthContext = request.state.auth
...
Raises:
ValueError: If 'request' parameter is missing
"""
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_auth decorator requires 'request' parameter")
# Authenticate and set context
auth_context = await _authenticate(request)
request.state.auth = auth_context
return await func(*args, **kwargs)
return wrapper
def require_permission(
resource: str,
action: str,
owner_check: bool = False,
owner_filter_key: str = "user_id",
inject_record: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator that checks permission for resource:action.
Must be used AFTER @require_auth.
Args:
resource: Resource name (e.g., "threads", "runs")
action: Action name (e.g., "read", "write", "delete")
owner_check: If True, validates that the current user owns the resource.
Requires 'thread_id' path parameter and performs ownership check.
owner_filter_key: Field name for ownership filter (default: "user_id")
inject_record: If True and owner_check is True, injects the thread record
into kwargs['thread_record'] for use in the handler.
Usage:
# Simple permission check
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
...
# With ownership check (for /threads/{thread_id} endpoints)
@require_permission("threads", "delete", owner_check=True)
async def delete_thread(thread_id: str, request: Request):
...
# With ownership check and record injection
@require_permission("threads", "delete", owner_check=True, inject_record=True)
async def delete_thread(thread_id: str, request: Request, thread_record: dict | None = None):
# thread_record is injected if found
...
Raises:
HTTPException 401: If authentication required but user is anonymous
HTTPException 403: If user lacks permission
HTTPException 404: If owner_check=True but user doesn't own the thread
ValueError: If owner_check=True but 'thread_id' parameter is missing
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
raise ValueError("require_permission decorator requires 'request' parameter")
auth: AuthContext = getattr(request.state, "auth", None)
if auth is None:
auth = await _authenticate(request)
request.state.auth = auth
if not auth.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
# Check permission
if not auth.has_permission(resource, action):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {resource}:{action}",
)
# Owner check for thread-specific resources
if owner_check:
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
# Get thread and verify ownership
from app.gateway.routers.threads import _store_get, get_store
store = get_store(request)
if store is not None:
record = await _store_get(store, thread_id)
if record:
owner_id = record.get("metadata", {}).get(owner_filter_key)
if owner_id and owner_id != str(auth.user.id):
raise HTTPException(
status_code=404,
detail=f"Thread {thread_id} not found",
)
# Inject record if requested
if inject_record:
kwargs["thread_record"] = record
return await func(*args, **kwargs)
return wrapper
return decorator
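# Quick illustration of AuthContext semantics (the user object is a
# hypothetical stand-in; any truthy object satisfies the runtime check):
if __name__ == "__main__":
    anon = AuthContext()
    assert not anon.is_authenticated
    assert not anon.has_permission("threads", "read")
    ctx = AuthContext(user=object(), permissions=[Permissions.THREADS_READ])  # type: ignore[arg-type]
    assert ctx.has_permission("threads", "read")
    assert not ctx.has_permission("threads", "delete")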


@ -0,0 +1,112 @@
"""CSRF protection middleware for FastAPI.
Per RFC-001:
State-changing operations require CSRF protection.
"""
import secrets
from collections.abc import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
def generate_csrf_token() -> str:
"""Generate a secure random CSRF token."""
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
def should_check_csrf(request: Request) -> bool:
"""Determine if a request needs CSRF validation.
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
"""
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me":
return False
return True
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/logout",
"/api/v1/auth/register",
}
)
def is_auth_endpoint(request: Request) -> bool:
"""Check if the request is to an auth endpoint.
Auth endpoints don't need CSRF validation on first call (no token).
"""
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
if not cookie_token or not header_token:
return JSONResponse(
status_code=403,
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
)
if not secrets.compare_digest(cookie_token, header_token):
return JSONResponse(
status_code=403,
content={"detail": "CSRF token mismatch."},
)
response = await call_next(request)
# For auth endpoints that set up session, also set CSRF cookie
if _is_auth and request.method == "POST":
# Generate a new CSRF token for the session
csrf_token = generate_csrf_token()
is_https = is_secure_request(request)
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
secure=is_https,
samesite="strict",
)
return response
def get_csrf_token(request: Request) -> str | None:
"""Get the CSRF token from the current request's cookies.
This is useful for server-side rendering where you need to embed
token in forms or headers.
"""
return request.cookies.get(CSRF_COOKIE_NAME)
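# Runnable sketch of the Double Submit check (the /api/v1/echo route is a
# hypothetical stand-in, not an endpoint from this commit):
if __name__ == "__main__":
    from fastapi import FastAPI
    from fastapi.testclient import TestClient

    demo = FastAPI()
    demo.add_middleware(CSRFMiddleware)

    @demo.post("/api/v1/echo")
    def echo() -> dict:
        return {"ok": True}

    client = TestClient(demo)
    assert client.post("/api/v1/echo").status_code == 403  # no cookie/header pair
    token = generate_csrf_token()
    client.cookies.set(CSRF_COOKIE_NAME, token)
    assert client.post("/api/v1/echo", headers={CSRF_HEADER_NAME: token}).status_code == 200
    assert client.post("/api/v1/echo", headers={CSRF_HEADER_NAME: "wrong"}).status_code == 403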


@ -3,38 +3,22 @@
**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING
from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager, StreamBridge
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons.
Usage in ``app.py``::
async with langgraph_runtime(app):
yield
"""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.runtime import make_store, make_stream_bridge
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
yield
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
# ---------------------------------------------------------------------------
# Getters called by routers per-request
@ -68,3 +52,102 @@ def get_checkpointer(request: Request):
def get_store(request: Request):
"""Return the global store (may be ``None`` if not configured)."""
return getattr(request.app.state, "store", None)
# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py)
# ---------------------------------------------------------------------------
# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton."""
global _cached_local_provider, _cached_repo
if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
_cached_repo = SQLiteUserRepository()
if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
return _cached_local_provider
async def get_current_user_from_request(request: Request):
"""Get the current authenticated user from the request cookie.
Raises HTTPException 401 if not authenticated.
"""
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
access_token = request.cookies.get("access_token")
if not access_token:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
)
payload = decode_token(access_token)
if isinstance(payload, TokenError):
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
)
provider = get_local_provider()
user = await provider.get_user(payload.sub)
if user is None:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
)
# Token version mismatch → password was changed, token is stale
if user.token_version != payload.ver:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
)
return user
async def get_optional_user_from_request(request: Request):
"""Get optional authenticated user from request.
Returns None if not authenticated.
"""
try:
return await get_current_user_from_request(request)
except HTTPException:
return None
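# Lifecycle sketch of the ``ver`` check above: login issues a JWT whose
# ``ver`` equals user.token_version (say 0); a password change bumps the
# stored token_version to 1; the old cookie still decodes successfully,
# but payload.ver (0) no longer matches, so the request fails with 401.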
# ---------------------------------------------------------------------------
# Runtime bootstrap
# ---------------------------------------------------------------------------
@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
"""Bootstrap and tear down all LangGraph runtime singletons.
Usage in ``app.py``::
async with langgraph_runtime(app):
yield
"""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.runtime import make_store, make_stream_bridge
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
yield


@ -0,0 +1,106 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.
Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.
Two layers:
1. @auth.authenticate validates JWT cookie, extracts user_id,
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on returns metadata filter so each user only sees own threads
"""
import secrets
from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider
auth = Auth()
# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
def _check_csrf(request) -> None:
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
proxied directly by nginx have the same CSRF protection.
"""
method = getattr(request, "method", "") or ""
if method.upper() not in _CSRF_METHODS:
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
if not cookie_token or not header_token:
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token missing. Include X-CSRF-Token header.",
)
if not secrets.compare_digest(cookie_token, header_token):
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token mismatch.",
)
@auth.authenticate
async def authenticate(request):
"""Validate the session cookie, decode JWT, and check token_version.
Same validation chain as Gateway's get_current_user_from_request:
cookie → decode JWT → DB lookup → token_version match
Also enforces CSRF on state-changing methods.
"""
# CSRF check before authentication so forged cross-site requests
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Not authenticated",
)
payload = decode_token(token)
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail=f"Token error: {payload.value}",
)
user = await get_local_provider().get_user(payload.sub)
if user is None:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="User not found",
)
if user.token_version != payload.ver:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Token revoked (password changed)",
)
return payload.sub
@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
"""Inject user_id metadata on writes; filter by user_id on reads.
Gateway stores thread ownership as ``metadata.user_id``.
This handler ensures LangGraph Server enforces the same isolation.
"""
# On create/update: stamp user_id into metadata
metadata = value.setdefault("metadata", {})
metadata["user_id"] = ctx.user.identity
# Return filter dict — LangGraph applies it to search/read/delete
return {"user_id": ctx.user.identity}
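# Net effect (sketch): a thread created by user A is stored with
# metadata == {"user_id": "<A>", ...}; when user B later searches or reads,
# the filter {"user_id": "<B>"} returned here excludes A's threads entirely.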


@ -1,3 +1,3 @@
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
from . import artifacts, assistants_compat, auth, mcp, models, skills, suggestions, thread_runs, threads, uploads
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
__all__ = ["artifacts", "assistants_compat", "auth", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]


@ -24,7 +24,7 @@ class AgentResponse(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
soul: str | None = Field(default=None, description="SOUL.md content (included on GET /{name})")
soul: str | None = Field(default=None, description="SOUL.md content")
class AgentsListResponse(BaseModel):
@ -92,17 +92,17 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
"/agents",
response_model=AgentsListResponse,
summary="List Custom Agents",
description="List all custom agents available in the agents directory.",
description="List all custom agents available in the agents directory, including their soul content.",
)
async def list_agents() -> AgentsListResponse:
"""List all custom agents.
Returns:
List of all custom agents with their metadata (without soul content).
List of all custom agents with their metadata and soul content.
"""
try:
agents = list_custom_agents()
return AgentsListResponse(agents=[_agent_config_to_response(a) for a in agents])
return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
except Exception as e:
logger.error(f"Failed to list agents: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")


@ -0,0 +1,303 @@
"""Authentication endpoints."""
import logging
import time
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field
from app.gateway.auth import (
UserResponse,
create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ── Request/Response Models ──────────────────────────────────────────────
class LoginResponse(BaseModel):
"""Response model for login — token only lives in HttpOnly cookie."""
expires_in: int # seconds
needs_setup: bool = False
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
class ChangePasswordRequest(BaseModel):
"""Request model for password change (also handles setup flow)."""
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
"""Set the access_token HttpOnly cookie on the response."""
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,  # persistent only over HTTPS; session cookie on plain HTTP
)
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}
def _get_client_ip(request: Request) -> str:
"""Extract the real client IP for rate limiting.
Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
$remote_addr``). Nginx unconditionally overwrites any client-supplied
``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
that nginx observed; it cannot be spoofed by the client.
``request.client.host`` is NOT reliable because uvicorn's default
``proxy_headers=True`` replaces it with the *first* entry from
``X-Forwarded-For``, which IS client-spoofable.
``X-Forwarded-For`` is intentionally NOT used for the same reason.
"""
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
# Fallback: direct connection without nginx (e.g. unit tests, dev).
return request.client.host if request.client else "unknown"
def _check_rate_limit(ip: str) -> None:
"""Raise 429 if the IP is currently locked out."""
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(
status_code=429,
detail="Too many login attempts. Try again later.",
)
del _login_attempts[ip]
_MAX_TRACKED_IPS = 10000
def _record_login_failure(ip: str) -> None:
"""Record a failed login attempt for the given IP."""
# Evict expired lockouts when dict grows too large
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for k in expired:
del _login_attempts[k]
# If still too large, evict cheapest-to-lose half: below-threshold
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for k, _ in by_time[: len(by_time) // 2]:
del _login_attempts[k]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
"""Clear failure counter for the given IP on successful login."""
_login_attempts.pop(ip, None)
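# State-machine sketch for a single IP (per worker process):
#   failures 1-4  -> (n, 0.0)          tracked, not locked
#   failure 5     -> (5, now + 300)    _check_rate_limit raises 429 until expiry
#   lock expired  -> entry dropped on the next check; counting restarts
#   login success -> entry dropped immediately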
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""Local email/password login."""
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
if user is None:
_record_login_failure(client_ip)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
)
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
"""Logout current user by clearing the cookie."""
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
"""Change password for the currently authenticated user.
Also handles the first-boot setup flow:
- If new_email is provided, updates email (checks uniqueness)
- If user.needs_setup is True and new_email is given, clears needs_setup
- Always increments token_version to invalidate old sessions
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
user = await get_current_user_from_request(request)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
if not await verify_password_async(body.current_password, user.password_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
provider = get_local_provider()
# Update email if provided
if body.new_email is not None:
existing = await provider.get_user_by_email(body.new_email)
if existing and str(existing.id) != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
user.email = body.new_email
# Update password + bump version
user.password_hash = await hash_password_async(body.new_password)
user.token_version += 1
# Clear setup flag if this is the setup flow
if user.needs_setup and body.new_email is not None:
user.needs_setup = False
await provider.update_user(user)
# Re-issue cookie with new token_version
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
"""Get current authenticated user info."""
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
@router.get("/setup-status")
async def setup_status():
"""Check if admin account exists. Always False after first boot."""
user_count = await get_local_provider().count_users()
return {"needs_setup": user_count == 0}
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
"""Initiate OAuth login flow.
Redirects to the OAuth provider's authorization URL.
Currently a placeholder - requires OAuth provider implementation.
"""
if provider not in ["github", "google"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider}",
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth login not yet implemented",
)
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
"""OAuth callback endpoint.
Handles the OAuth provider's callback after user authorization.
Currently a placeholder.
"""
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth callback not yet implemented",
)
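# Hedged end-to-end sketch of the first-boot setup flow these endpoints
# define (httpx assumed; every credential below is a placeholder):
#
#     import httpx
#
#     c = httpx.Client(base_url="http://localhost:8000")
#     r = c.post("/api/v1/auth/login/local",
#                data={"username": "admin@example.invalid", "password": "boot-pw"})
#     if r.json().get("needs_setup"):
#         c.post("/api/v1/auth/change-password",
#                headers={"X-CSRF-Token": c.cookies.get("csrf_token", "")},
#                json={"current_password": "boot-pw",
#                      "new_password": "a-strong-new-password",
#                      "new_email": "me@example.com"})  # clears needs_setup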


@ -2,6 +2,7 @@ import json
import logging
from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from deerflow.models import create_chat_model
@ -106,22 +107,21 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
if not conversation:
return SuggestionsResponse(suggestions=[])
prompt = (
system_instruction = (
"You are generating follow-up questions to help the user continue the conversation.\n"
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
"Requirements:\n"
"- Questions must be relevant to the conversation.\n"
"- Questions must be relevant to the preceding conversation.\n"
"- Questions must be written in the same language as the user.\n"
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
"- Do NOT include numbering, markdown, or any extra text.\n"
"- Output MUST be a JSON array of strings only.\n\n"
"Conversation:\n"
f"{conversation}\n"
"- Output MUST be a JSON array of strings only.\n"
)
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = model.invoke(prompt)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]


@ -19,6 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.authz import require_auth, require_permission
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
@ -92,19 +93,28 @@ def _record_to_response(record: RunRecord) -> RunResponse:
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_auth
@require_permission("runs", "create", owner_check=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
"""Create a background run (returns immediately).
Multi-tenant isolation: only the thread owner can create runs.
"""
record = await start_run(body, thread_id, request)
return _record_to_response(record)
@router.post("/{thread_id}/runs/stream")
@require_auth
@require_permission("runs", "create", owner_check=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
The response includes a ``Content-Location`` header with the run's
resource URL, matching the LangGraph Platform protocol. The
``useStream`` React hook uses this to extract run metadata.
Multi-tenant isolation: only the thread owner can stream runs.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
@ -125,8 +135,13 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_auth
@require_permission("runs", "create", owner_check=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
"""Create a run and block until it completes, returning the final state.
Multi-tenant isolation: only the thread owner can wait for runs.
"""
record = await start_run(body, thread_id, request)
if record.task is not None:
@ -150,16 +165,26 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_auth
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
"""List all runs for a thread.
Multi-tenant isolation: only the thread owner can list runs.
"""
run_mgr = get_run_manager(request)
records = await run_mgr.list_by_thread(thread_id)
return [_record_to_response(r) for r in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_auth
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
"""Get details of a specific run.
Multi-tenant isolation: only the thread owner can get runs.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
@ -168,6 +193,8 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_auth
@require_permission("runs", "cancel", owner_check=True)
async def cancel_run(
thread_id: str,
run_id: str,
@ -181,6 +208,8 @@ async def cancel_run(
- action=rollback: Stop execution, revert to pre-run checkpoint state
- wait=true: Block until the run fully stops, return 204
- wait=false: Return immediately with 202
Multi-tenant isolation: only the thread owner can cancel runs.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
@ -205,8 +234,13 @@ async def cancel_run(
@router.get("/{thread_id}/runs/{run_id}/join")
@require_auth
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
"""Join an existing run's SSE stream.
Multi-tenant isolation: only the thread owner can join runs.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)


@ -13,17 +13,26 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations
import logging
import re
import time
import uuid
from typing import Any
from typing import Annotated, Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from fastapi import APIRouter, HTTPException, Path, Request
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_auth, require_permission
from app.gateway.deps import get_checkpointer, get_store
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
# ---------------------------------------------------------------------------
# Thread ID validation (prevents log-injection via control characters)
# ---------------------------------------------------------------------------
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
ThreadId = Annotated[str, Path(description="Thread UUID", pattern=_UUID_RE.pattern)]
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
@ -65,6 +74,13 @@ class ThreadCreateRequest(BaseModel):
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
@field_validator("thread_id")
@classmethod
def _validate_uuid(cls, v: str | None) -> str | None:
if v is not None and not _UUID_RE.match(v):
raise ValueError("thread_id must be a valid UUID")
return v
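# e.g. ThreadCreateRequest(thread_id="../etc/passwd") raises a
# ValidationError, while ThreadCreateRequest(thread_id=str(uuid.uuid4()))
# is accepted unchanged.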
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
@ -215,17 +231,23 @@ def _derive_thread_status(checkpoint_tuple) -> str:
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
@require_auth
@require_permission("threads", "delete", owner_check=True)
async def delete_thread_data(thread_id: ThreadId, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread record from the Store.
Multi-tenant isolation: only the thread owner can delete their thread.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
# Clean local filesystem
response = _delete_thread_data(thread_id)
# Remove from Store (best-effort)
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
@ -233,7 +255,6 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None:
try:
if hasattr(checkpointer, "adelete_thread"):
@ -251,12 +272,23 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
The thread record is written to the Store (for fast listing) and an
empty checkpoint is written to the checkpointer (for state reads).
Idempotent: returns the existing record when ``thread_id`` already exists.
If authenticated, the user's ID is injected into the thread metadata
for multi-tenant isolation.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
thread_metadata = dict(body.metadata)
if user:
thread_metadata["user_id"] = str(user.id)
# Idempotency: return existing record from Store when already present
if store is not None:
existing_record = await _store_get(store, thread_id)
@ -279,7 +311,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": body.metadata,
"metadata": thread_metadata,
},
)
except Exception:
@ -296,7 +328,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
"source": "input",
"writes": None,
"parents": {},
**body.metadata,
**thread_metadata,
"created_at": now,
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
@ -304,13 +336,13 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
logger.exception("Failed to create checkpoint for thread %s", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", thread_id)
logger.info("Thread created: %s (user_id=%s)", thread_id, thread_metadata.get("user_id"))
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
metadata=thread_metadata,
)
@ -330,10 +362,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
newly found thread is immediately written to the Store so that the next
search skips Phase 2 for that thread; the Store converges to a full
index over time without a one-shot migration job.
If authenticated, only threads belonging to the current user are returned
(enforced by user_id metadata filter for multi-tenant isolation).
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
user_id = str(user.id) if user else None
# -----------------------------------------------------------------------
# Phase 1: Store
# -----------------------------------------------------------------------
@ -409,6 +449,10 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
# -----------------------------------------------------------------------
results = list(merged.values())
# Multi-tenant isolation: filter by user_id if authenticated
if user_id:
results = [r for r in results if r.metadata.get("user_id") == user_id]
if body.metadata:
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
@ -420,13 +464,20 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
@router.patch("/{thread_id}", response_model=ThreadResponse)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
@require_auth
@require_permission("threads", "write", owner_check=True, inject_record=True)
async def patch_thread(thread_id: ThreadId, request: Request, body: ThreadPatchRequest, thread_record: dict | None = None) -> ThreadResponse:
"""Merge metadata into a thread record.
Multi-tenant isolation: only the thread owner can patch their thread.
"""
store = get_store(request)
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
record = await _store_get(store, thread_id)
record = thread_record
if record is None:
record = await _store_get(store, thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
@ -451,12 +502,17 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
@router.get("/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: ThreadId, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the Store and derives the accurate execution
status from the checkpointer. Falls back to the checkpointer alone
for threads that pre-date Store adoption (backward compat).
Multi-tenant isolation: returns 404 if the thread does not belong to
the authenticated user.
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
@ -488,26 +544,33 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle") # type: ignore[union-attr]
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
channel_values = checkpoint.get("channel_values", {})
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=str(record.get("created_at", "")), # type: ignore[union-attr]
updated_at=str(record.get("updated_at", "")), # type: ignore[union-attr]
metadata=record.get("metadata", {}), # type: ignore[union-attr]
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: ThreadId, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
Channel values are serialized to ensure LangChain message objects
are converted to JSON-safe dicts.
Multi-tenant isolation: returns 404 if thread does not belong to user.
"""
checkpointer = get_checkpointer(request)
@ -552,12 +615,16 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
@require_auth
@require_permission("threads", "write", owner_check=True)
async def update_thread_state(thread_id: ThreadId, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field back to the Store
so that ``/threads/search`` reflects the change immediately.
Multi-tenant isolation: only the thread owner can update their thread.
"""
checkpointer = get_checkpointer(request)
store = get_store(request)
@ -635,8 +702,13 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: ThreadId, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread.
Multi-tenant isolation: returns 404 if thread does not belong to user.
"""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}

View File

@ -116,6 +116,7 @@ def build_run_config(
metadata: dict[str, Any] | None,
*,
assistant_id: str | None = None,
user_id: str | None = None,
) -> dict[str, Any]:
"""Build a RunnableConfig dict for the agent.
@ -128,6 +129,9 @@ def build_run_config(
This mirrors the channel manager's ``_resolve_run_params`` logic so that
the LangGraph Platform-compatible HTTP API and the IM channel path behave
identically.
If *user_id* is provided, it is injected into the config metadata for
multi-tenant isolation.
"""
config: dict[str, Any] = {"recursion_limit": 100}
if request_config:
@ -161,6 +165,11 @@ def build_run_config(
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
config["configurable"]["agent_name"] = normalized
# Multi-tenant isolation: inject user_id into metadata
if user_id:
config.setdefault("metadata", {})["user_id"] = user_id
if metadata:
config.setdefault("metadata", {}).update(metadata)
return config
@ -260,6 +269,10 @@ async def start_run(
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
# Reuse auth context set by @require_auth decorator to avoid redundant DB lookup
auth = getattr(request.state, "auth", None)
user_id = str(auth.user.id) if auth and auth.user else None
try:
record = await run_mgr.create_or_reject(
thread_id,
@ -282,7 +295,13 @@ async def start_run(
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
config = build_run_config(
thread_id,
body.config,
body.metadata,
assistant_id=body.assistant_id,
user_id=user_id,
)
# Merge DeerFlow-specific context overrides into configurable.
# The ``context`` field is a custom extension for the langgraph-compat layer

File diff suppressed because it is too large

View File

@ -0,0 +1,129 @@
# Authentication Upgrade Guide
DeerFlow ships with a built-in authentication module. This document is for users upgrading from a version without authentication.
## Core Concepts
The authentication module follows an **always-enforced** policy:
- An admin account is created automatically on first boot, with a random password printed to the console log
- Authentication is enforced from the very start, with no race window
- Historical conversations (threads created before the upgrade) are migrated to the admin account automatically
## Upgrade Steps
### 1. Update the code
```bash
git pull origin main
cd backend && make install
```
### 2. First boot
```bash
make dev
```
The console prints:
```
============================================================
Admin account created on first boot
Email: admin@deerflow.dev
Password: aB3xK9mN_pQ7rT2w
Change it after login: Settings → Account
============================================================
```
If you restart the service before logging in, don't worry: as long as setup is incomplete, the password is reset and printed to the console again on every startup.
### 3. Log in
Visit `http://localhost:2026/login` and sign in with the email and password printed to the console.
### 4. Change the password
After logging in, go to Settings → Account → Change Password.
### 5. Add users (optional)
Other users register via the `/login` page and automatically receive the **user** role. Each user can only see their own conversations.
## Security Mechanisms
| Mechanism | Description |
|------|------|
| JWT HttpOnly cookie | The token is never exposed to JavaScript, preventing XSS theft |
| CSRF double-submit cookie | All POST/PUT/DELETE requests must carry `X-CSRF-Token` |
| bcrypt password hashing | Passwords are never stored in plaintext |
| Multi-tenant isolation | Users can only access their own threads |
| HTTPS auto-detection | Detects `x-forwarded-proto` and sets the `Secure` cookie flag automatically |
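For reference, the double-submit pattern simply requires each mutating request to echo the CSRF cookie back in a header. A minimal sketch of the idea (illustrative only, not the shipped middleware; the `csrf_token` cookie name is an assumption, while the `X-CSRF-Token` header comes from the table above):

```python
# Minimal double-submit CSRF check (sketch, not the actual DeerFlow middleware).
# Assumes the cookie is named "csrf_token".
import secrets

from fastapi import HTTPException, Request

UNSAFE_METHODS = {"POST", "PUT", "PATCH", "DELETE"}

async def verify_csrf(request: Request) -> None:
    if request.method not in UNSAFE_METHODS:
        return  # safe methods need no CSRF proof
    cookie = request.cookies.get("csrf_token")
    header = request.headers.get("X-CSRF-Token")
    # compare_digest avoids leaking information through timing differences
    if not cookie or not header or not secrets.compare_digest(cookie, header):
        raise HTTPException(status_code=403, detail="CSRF token missing or invalid")
```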
## Common Operations
### Forgot password
```bash
cd backend
# Reset the admin password
python -m app.gateway.auth.reset_admin
# Reset a specific user's password
python -m app.gateway.auth.reset_admin --email user@example.com
```
A new random password is printed.
### Full reset
Delete the user database; a new admin is created automatically on restart:
```bash
rm -f backend/.deer-flow/users.db
# Restart the service; the console prints the new password
```
## Data Storage
| File | Contents |
|------|------|
| `.deer-flow/users.db` | SQLite user database (password hashes, roles) |
| `AUTH_JWT_SECRET` in `.env` | JWT signing key (an ephemeral key is generated when unset; sessions are invalidated on restart) |
### Production recommendations
```bash
# Generate a persistent JWT secret so users are not logged out on every restart
python -c "import secrets; print(secrets.token_urlsafe(32))"
# Add the output to .env:
# AUTH_JWT_SECRET=<generated secret>
```
## API Endpoints
| Endpoint | Method | Description |
|------|------|------|
| `/api/v1/auth/login/local` | POST | Email/password login (OAuth2 form) |
| `/api/v1/auth/register` | POST | Register a new user (user role) |
| `/api/v1/auth/logout` | POST | Log out (clears the cookie) |
| `/api/v1/auth/me` | GET | Get the current user's info |
| `/api/v1/auth/change-password` | POST | Change password |
| `/api/v1/auth/setup-status` | GET | Check whether an admin exists |
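To smoke-test these endpoints from a script, the flow below logs in, reads the current user, and logs out. This is a sketch assuming the endpoint table above; the `csrf_token` cookie name is an assumption, so adjust it to the actual deployment:

```python
# Illustrative client flow against the endpoints above (not shipped code).
import httpx

with httpx.Client(base_url="http://localhost:2026") as client:
    # OAuth2 form login; the server sets the HttpOnly session cookie on success.
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": "admin@deerflow.dev", "password": "<password from console>"},
    )
    resp.raise_for_status()

    print(client.get("/api/v1/auth/me").json())

    # Mutating requests must echo the CSRF cookie in the X-CSRF-Token header.
    csrf = client.cookies.get("csrf_token") or ""
    client.post("/api/v1/auth/logout", headers={"X-CSRF-Token": csrf})
```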
## Compatibility
- **Standard mode** (`make dev`): fully compatible; the admin is auto-created
- **Gateway mode** (`make dev-pro`): fully compatible
- **Docker deployment**: fully compatible; `.deer-flow/users.db` needs a persistent volume mount
- **IM channels** (Feishu/Slack/Telegram): communicate via the LangGraph SDK and do not pass through the auth layer
- **DeerFlowClient** (embedded): does not go through HTTP and is unaffected by authentication
## Troubleshooting
| Symptom | Cause | Fix |
|------|------|------|
| No password shown at startup | Admin already exists (not a first boot) | Reset with `reset_admin`, or delete `users.db` |
| POST returns 403 after login | Missing CSRF token | Make sure the frontend is up to date |
| Re-login required after restart | `AUTH_JWT_SECRET` not persisted | Set a fixed key in `.env` |

View File

@ -248,7 +248,7 @@ def after_agent(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | N
- [`packages/harness/deerflow/agents/thread_state.py`](../packages/harness/deerflow/agents/thread_state.py) - ThreadState definition
- [`packages/harness/deerflow/agents/middlewares/title_middleware.py`](../packages/harness/deerflow/agents/middlewares/title_middleware.py) - TitleMiddleware implementation
- [`packages/harness/deerflow/config/title_config.py`](../packages/harness/deerflow/config/title_config.py) - Configuration management
- [`config.yaml`](../config.yaml) - Configuration file
- [`config.yaml`](../../config.example.yaml) - Configuration file
- [`packages/harness/deerflow/agents/lead_agent/agent.py`](../packages/harness/deerflow/agents/lead_agent/agent.py) - Middleware registration
## References

View File

@ -30,7 +30,7 @@
### 2. Configuration files
#### [`config.yaml`](../config.yaml)
#### [`config.yaml`](../../config.example.yaml)
- ✅ Added the `title` config section:
```yaml
title:
@ -51,7 +51,7 @@ title:
- ✅ Troubleshooting guide
- ✅ State vs Metadata comparison
#### [`BACKEND_TODO.md`](../BACKEND_TODO.md)
#### [`TODO.md`](TODO.md)
- ✅ Recorded the feature as complete
### 4. Tests

View File

@ -0,0 +1,446 @@
# [RFC] Add `grep` and `glob` file search tools to DeerFlow
## Summary
I believe this direction is right, and worth doing.
If DeerFlow wants to get closer to the real workflow of coding agents like Claude Code, having only `ls` / `read_file` / `write_file` / `str_replace` is not enough. Before making edits, the model usually needs two more capabilities:
- `glob`: quickly find files by path pattern
- `grep`: quickly find candidate locations by content pattern
The value of these tools is not that "bash can do this too"; it is that they replace the model's habit of reaching for `bash find` / `bash grep` / `rg` with lower token cost, stronger constraints, and a more stable output format.
But only if the implementation is right: **they should be read-only, structured, restricted, auditable native tools, not thin wrappers around shell commands.**
## Problem
DeerFlow's file tool layer currently covers:
- `ls`: browse directory structure
- `read_file`: read file contents
- `write_file`: write files
- `str_replace`: perform local string replacements
- `bash`: run commands as a fallback
This set gets the job done, but it is inefficient during codebase exploration.
Typical problems:
1. To find "all `*.tsx` page files", the model can only `ls` through layer after layer of directories, or fall back to `bash find`
2. To find "where a symbol / string / config key appears", it can only `read_file` file by file, or fall back to `bash grep` / `rg`
3. Once it falls back to `bash`, tool calls lose structured output, and results become harder to trim, paginate, audit, and keep consistent across sandboxes
4. In local mode without host bash enabled, `bash` may not even be available, leaving no sufficiently strong read-only search capability
Conclusion: what DeerFlow lacks is not "one more shell command" but a **filesystem search layer**.
## Goals
- Give the agent stable path search and content search capabilities
- Reduce dependence on `bash`, especially during repository exploration
- Stay consistent with the existing sandbox security model
- Produce structured output that the model can chain into `read_file` / `str_replace`
- Let local sandboxes, container sandboxes, and future MCP filesystem tools all follow the same semantics
## Non-Goals
- No general-purpose shell compatibility layer
- No exposure of the full grep/find/rg CLI syntax
- No heavy features in v1 such as binary search, complex PCRE features, or context-window highlight rendering
- No "search anywhere on disk"; execution stays restricted to paths DeerFlow has authorized
## Why This Is Worth Doing
Following the design thinking of agents like Claude Code, the core value of `glob` and `grep` is not the new capability itself but moving the common "explore the codebase" actions from an open-ended shell down to a controlled tool layer.
This yields several direct benefits:
1. **Lower model burden**
The model does not need to piece together `find`, `grep`, `rg`, `xargs`, quoting, and other command details itself.
2. **More stable cross-environment behavior**
Local, Docker, and AIO sandboxes no longer depend on whether `rg` is installed in the container, and behavior does not drift with shell differences.
3. **Stronger security and auditing**
The call parameters amount to "what to search, where, and how many results at most", which is inherently easier to audit and rate-limit than arbitrary commands.
4. **Better token efficiency**
`grep` returns match summaries rather than whole files; the model then calls `read_file` on only a few candidate paths.
5. **Friendly to `tool_search`**
As DeerFlow keeps expanding its toolset, `grep` / `glob` will be very high-frequency foundational tools, worth keeping as built-ins rather than letting the model always fall back to generic bash.
## Proposal
Add two built-in sandbox tools:
- `glob`
- `grep`
Recommended location, same as today:
- `backend/packages/harness/deerflow/sandbox/tools.py`
and add them to the `file:read` group in `config.example.yaml` by default.
### 1. The `glob` tool
Purpose: find files or directories by path pattern.
Suggested schema:
```python
@tool("glob", parse_docstring=True)
def glob_tool(
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
include_dirs: bool = False,
max_results: int = 200,
) -> str:
...
```
Parameter semantics:
- `description`: consistent with existing tools
- `pattern`: a glob pattern, e.g. `**/*.py` or `src/**/test_*.ts`
- `path`: the search root; must be an absolute path
- `include_dirs`: whether to return directories
- `max_results`: maximum number of entries returned, preventing a single call from blowing up the context
Suggested return format:
```text
Found 3 paths under /mnt/user-data/workspace
1. /mnt/user-data/workspace/backend/app.py
2. /mnt/user-data/workspace/backend/tests/test_app.py
3. /mnt/user-data/workspace/scripts/build.py
```
If we later want output that is easier for the frontend to consume, this could become a JSON string; for v1, readable text keeps it consistent with the existing tool style.
### 2. The `grep` tool
Purpose: search files by content pattern and return a summary of match locations.
Suggested schema:
```python
@tool("grep", parse_docstring=True)
def grep_tool(
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> str:
...
```
Parameter semantics:
- `pattern`: a search term or regex
- `path`: the search root; must be an absolute path
- `glob`: optional path filter, e.g. `**/*.py`
- `literal`: when `True`, match as a plain string instead of interpreting the pattern as a regex
- `case_sensitive`: whether matching is case-sensitive
- `max_results`: maximum number of matches returned, not files
Suggested return format:
```text
Found 4 matches under /mnt/user-data/workspace
/mnt/user-data/workspace/backend/config.py:12: TOOL_GROUPS = [...]
/mnt/user-data/workspace/backend/config.py:48: def load_tool_config(...):
/mnt/user-data/workspace/backend/tools.py:91: "tool_groups"
/mnt/user-data/workspace/backend/tests/test_config.py:22: assert "tool_groups" in data
```
For v1, return only:
- the file path
- the line number
- a summary of the matched line
No context blocks, to keep results small. If the model needs context, it can call `read_file(path, start_line, end_line)`.
## Design Principles
### A. No shell wrappers
Do not implement `grep` as:
```python
subprocess.run("grep ...")
```
Nor should we assemble raw `find` / `rg` command lines inside the container.
Reasons:
- It introduces shell quoting issues and an injection surface
- It depends on whether each sandbox image ships the same set of commands
- Behavior differs across Windows / macOS / Linux
- It is hard to reliably control result counts and output format
The right direction, sketched below, is:
- `glob` walks paths with the Python standard library
- `grep` scans files one by one in Python
- Output is formatted by DeerFlow itself
If we later prefer `rg` for performance, it should be encapsulated inside the provider with the external semantics unchanged, rather than exposing the CLI to the model.
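A minimal pure-Python core in that spirit (an illustrative sketch, not the shipped implementation; ignore rules, path validation, and virtual-path mapping are deliberately omitted):

```python
# Sketch of the pure-Python glob/grep cores (limits are the RFC's defaults).
import re
from pathlib import Path

def glob_paths(root: str, pattern: str, max_results: int = 200) -> tuple[list[str], bool]:
    """Return (matches, truncated) for files under *root* matching *pattern*."""
    hits: list[str] = []
    for p in sorted(Path(root).glob(pattern)):
        if not p.is_file():
            continue
        if len(hits) >= max_results:
            return hits, True  # at least one more match exists: truncated
        hits.append(str(p))
    return hits, False

def grep_lines(
    root: str,
    pattern: str,
    *,
    literal: bool = False,
    case_sensitive: bool = False,
    max_results: int = 100,
) -> tuple[list[tuple[str, int, str]], bool]:
    """Return ((path, line_number, line_summary), ...) plus a truncation flag."""
    flags = 0 if case_sensitive else re.IGNORECASE
    rx = re.compile(re.escape(pattern) if literal else pattern, flags)
    matches: list[tuple[str, int, str]] = []
    for path in sorted(Path(root).rglob("*")):
        if not path.is_file():
            continue
        try:
            text = path.read_text(encoding="utf-8")
        except (UnicodeDecodeError, OSError):
            continue  # skip binary or unreadable files
        for lineno, line in enumerate(text.splitlines(), start=1):
            if rx.search(line):
                matches.append((str(path), lineno, line.strip()[:200]))
                if len(matches) >= max_results:
                    return matches, True
    return matches, False
```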
### B. Keep DeerFlow's path permission model
Both tools must reuse the current path validation logic of `ls` / `read_file`:
- Local mode goes through `validate_local_tool_path(..., read_only=True)`
- Support `/mnt/skills/...`
- Support `/mnt/acp-workspace/...`
- Support virtual-path resolution for thread workspace / uploads / outputs
- Explicitly reject out-of-scope paths and path traversal
In other words, they belong to **file:read**; they are not a privilege-bypassing side entrance around `bash`.
### C. Results must be hard-limited
Without hard limits, `glob` / `grep` can easily blow up the context.
v1 should limit at least:
- `glob.max_results`: default 200, max 1000
- `grep.max_results`: default 100, max 500
- Maximum line-summary length, e.g. 200 characters
- Skip binary files
- Skip oversized files, e.g. single files over 1 MB, or as configured
Additionally, when matches exceed the threshold, the output should state:
- how many entries are shown
- the fact that results were truncated
- a suggestion to narrow the search
For example:
```text
Found more than 100 matches, showing first 100. Narrow the path or add a glob filter.
```
### D. Tool semantics should complement each other
The recommended model workflow is:
1. `glob` to find candidate files
2. `grep` to find candidate locations
3. `read_file` to read local context
4. `str_replace` / `write_file` to make the edit
This keeps tool boundaries clean and makes it easier to teach the model stable habits in the prompt.
## Implementation Approach
### Option A: Implement v1 directly in `sandbox/tools.py`
This is my recommended starting point.
Approach:
- Add `glob_tool` and `grep_tool` to `sandbox/tools.py`
- In the local sandbox case, use the Python filesystem APIs directly
- In non-local sandbox cases, also prefer going through DeerFlow's own controlled path access layer
Pros:
- Small change
- Validates the agent-facing value quickly
- No need to change the `Sandbox` abstraction first
Cons:
- `tools.py` keeps growing
- Provider-side performance optimization later would require another round of abstraction
### Option B: Extend the `Sandbox` abstraction first
For example, add:
```python
class Sandbox(ABC):
def glob(self, path: str, pattern: str, include_dirs: bool = False, max_results: int = 200) -> list[str]:
...
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> list[GrepMatch]:
...
```
Pros:
- Cleaner abstraction
- Container / remote sandboxes can each optimize independently
Cons:
- Higher upfront cost
- All sandbox providers must change in lockstep
Conclusion:
**Go with Option A for v1; push down into the `Sandbox` abstraction layer once the tools' value is proven.**
## Detailed Behavior
### `glob` behavior
- Root directory does not exist: return a clear error
- Root path is not a directory: return a clear error
- Invalid pattern: return a clear error
- Empty result: return `No files matched`
- Default ignores should align with the current `list_dir` as much as possible, e.g.:
  - `.git`
  - `node_modules`
  - `__pycache__`
  - `.venv`
  - build output directories
I suggest extracting a shared ignore set here so `ls` and `glob` do not diverge in what they show; a sketch follows.
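A minimal shape for that shared predicate (an illustrative sketch; in this commit the real, much longer pattern list lives in `deerflow.sandbox.search`):

```python
# Sketch of a shared ignore predicate reused by ls/glob/grep.
# The pattern tuple here is abbreviated; the real list also covers VCS
# metadata, IDE files, logs, caches, and test artifacts.
import fnmatch

IGNORE_PATTERNS = (".git", "node_modules", "__pycache__", ".venv", "dist", "build")

def should_ignore_name(name: str) -> bool:
    """Check a single path component against the shared patterns."""
    return any(fnmatch.fnmatch(name, pat) for pat in IGNORE_PATTERNS)

def should_ignore_path(path: str) -> bool:
    """Ignore a path when any of its components matches a pattern."""
    return any(should_ignore_name(part) for part in path.split("/") if part)
```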
### `grep` behavior
- Scan only text files by default
- Skip files detected as binary (see the sniff sketch below)
- Skip oversized files outright, or scan only the first N KB
- Return a parameter error when the regex fails to compile
- Keep using virtual paths in the output rather than exposing real host paths
- Sort by file path, then line number, by default for stable output
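A common way to implement the binary and size guards is a NUL-byte sniff on the first block of the file; a minimal sketch (the thresholds are the RFC's example values, not fixed decisions):

```python
# Illustrative binary/oversize guard for grep.
from pathlib import Path

MAX_FILE_BYTES = 1 * 1024 * 1024  # skip single files over 1 MB

def is_searchable_text(path: Path, sniff_bytes: int = 8192) -> bool:
    """Cheap check: reject oversized files and anything with a NUL byte up front."""
    try:
        if path.stat().st_size > MAX_FILE_BYTES:
            return False
        with path.open("rb") as f:
            chunk = f.read(sniff_bytes)
    except OSError:
        return False
    return b"\x00" not in chunk  # NUL byte present => treat as binary
```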
## Prompting Guidance
If these tools land, the file operation advice in the system prompt should be updated in step:
- Prefer `glob` when searching for filename patterns
- Prefer `grep` when searching for code symbols, config keys, or copy
- Fall back to `bash` only when the tools cannot accomplish the goal
Otherwise the model will keep habitually reaching for `bash` first.
## Risks
### 1. Overlap with `bash`
True, but not a problem.
`ls` and `read_file` could also be replaced by `bash`, yet we keep them because structured tools suit agents better.
### 2. Performance
On large repositories, pure-Python `grep` can be slower than `rg`.
Mitigations:
- Add result caps and file-size caps in v1
- Require an explicit root path
- Provide `glob` filtering to narrow the scan scope
- If needed later, optimize with `rg` inside the provider while keeping the same schema
### 3. Inconsistent ignore rules
If `glob` cannot see a path that `ls` can, the model gets confused.
Mitigations:
- Unify the ignore rules
- State clearly in the docs that common dependency and build directories are skipped by default
### 4. Regex search growing too complex
If v1 supports a pile of grep dialects, the boundaries get messy.
Mitigations:
- v1 supports only Python `re`
- plus a simple `literal=True` mode
## Alternatives Considered
### A. Add no tools and rely entirely on `bash`
Not recommended.
This keeps DeerFlow behind on code exploration and weakens it in no-bash or restricted-bash scenarios.
### B. Add only `glob`, not `grep`
Not recommended.
It solves "find files" but not "find locations"; the model will still fall back to `bash grep`.
### C. Add only `grep`, not `glob`
Also not recommended.
Without path-pattern filtering, `grep` often scans far too much; `glob` is its natural upstream companion.
### D. Plug in the MCP filesystem server's search capabilities directly
Not recommended as the primary path in the short term.
MCP can complement this, but as DeerFlow's foundational coding tools, `glob` / `grep` are best kept built-in so they are reliably available in a default install.
## Acceptance Criteria
- `glob` and `grep` can be enabled by default in `config.example.yaml`
- Both tools belong to `file:read`
- The local sandbox strictly honors existing path permissions
- Output never leaks real host paths
- Large result sets are truncated with an explicit notice
- The model can complete a typical code-editing flow via `glob -> grep -> read_file -> str_replace`
- Repository exploration improves noticeably in local mode with host bash disabled
## Rollout Plan
1. Implement `glob_tool` and `grep_tool` in `sandbox/tools.py`
2. Extract ignore rules shared with `list_dir` to avoid behavioral drift
3. Add the tool configuration to `config.example.yaml` by default
4. Add tests for local path validation, virtual path mapping, result truncation, and binary skipping
5. Update README / backend docs / prompt guidance
6. Collect real agent call data before deciding whether to push down into the `Sandbox` abstraction
## Suggested Config
```yaml
tools:
  - name: glob
    group: file:read
    use: deerflow.sandbox.tools:glob_tool
  - name: grep
    group: file:read
    use: deerflow.sandbox.tools:grep_tool
```
## Final Recommendation
The conclusion: **we can add these, and we should.**
But I would enforce three boundaries:
1. `grep` / `glob` must be built-in, read-only, structured tools
2. No shell wrappers in v1; do not expose CLI dialects directly to the model
3. Prove the value in `sandbox/tools.py` first, then consider pushing down into the `Sandbox` provider abstraction
Done this way, it clearly improves DeerFlow's usability for coding / repo exploration, and the risk stays controllable.

View File

@ -8,6 +8,9 @@
"graphs": {
"lead_agent": "deerflow.agents:make_lead_agent"
},
"auth": {
"path": "./app/gateway/langgraph_auth.py:auth"
},
"checkpointer": {
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
}

View File

@ -8,6 +8,14 @@ from deerflow.subagents import get_available_subagent_names
logger = logging.getLogger(__name__)
def _get_enabled_skills():
try:
return list(load_skills(enabled_only=True))
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
return []
def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
@ -386,7 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
Returns the <skill_system>...</skill_system> block listing all enabled skills,
suitable for injection into any agent's system prompt.
"""
skills = load_skills(enabled_only=True)
skills = _get_enabled_skills()
try:
from deerflow.config import get_app_config
@ -450,7 +458,7 @@ def get_deferred_tools_prompt_section() -> str:
if not get_app_config().tool_search.enabled:
return ""
except FileNotFoundError:
except Exception:
return ""
registry = get_deferred_registry()

View File

@ -246,6 +246,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
if earlier.get("summary"):
history_sections.append(f"Earlier: {earlier['summary']}")
background = history_data.get("longTermBackground", {})
if background.get("summary"):
history_sections.append(f"Background: {background['summary']}")
if history_sections:
sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))

View File

@ -21,6 +21,7 @@ class ConversationContext:
timestamp: datetime = field(default_factory=datetime.utcnow)
agent_name: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
class MemoryUpdateQueue:
@ -44,6 +45,7 @@ class MemoryUpdateQueue:
messages: list[Any],
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
"""Add a conversation to the update queue.
@ -52,6 +54,7 @@ class MemoryUpdateQueue:
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
config = get_memory_config()
if not config.enabled:
@ -63,11 +66,13 @@ class MemoryUpdateQueue:
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
context = ConversationContext(
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
# Check if this thread already has a pending update
@ -130,6 +135,7 @@ class MemoryUpdateQueue:
thread_id=context.thread_id,
agent_name=context.agent_name,
correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)

View File

@ -246,7 +246,7 @@ def _fact_content_key(content: Any) -> str | None:
stripped = content.strip()
if not stripped:
return None
return stripped
return stripped.casefold()
class MemoryUpdater:
@ -272,6 +272,7 @@ class MemoryUpdater:
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Update memory based on conversation messages.
@ -280,6 +281,7 @@ class MemoryUpdater:
thread_id: Optional thread ID for tracking source.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if update was successful, False otherwise.
@ -310,6 +312,14 @@ class MemoryUpdater:
"and record the correct approach as a fact with category "
'"correction" and confidence >= 0.95 when appropriate.'
)
if reinforcement_detected:
reinforcement_hint = (
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
"The user explicitly confirmed the agent's approach was correct or helpful. "
"Record the confirmed approach, style, or preference as a fact with category "
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
)
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
prompt = MEMORY_UPDATE_PROMPT.format(
current_memory=json.dumps(current_memory, indent=2),
@ -441,6 +451,7 @@ def update_memory_from_conversation(
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> bool:
"""Convenience function to update memory from a conversation.
@ -449,9 +460,10 @@ def update_memory_from_conversation(
thread_id: Optional thread ID.
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)

View File

@ -182,6 +182,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
return None, False
@staticmethod
def _append_text(content: str | list | None, text: str) -> str | list:
"""Append *text* to AIMessage content, handling str, list, and None.
When content is a list of content blocks (e.g. Anthropic thinking mode),
we append a new ``{"type": "text", ...}`` block instead of concatenating
a string to a list, which would raise ``TypeError``.
"""
if content is None:
return text
if isinstance(content, list):
return [*content, {"type": "text", "text": f"\n\n{text}"}]
if isinstance(content, str):
return content + f"\n\n{text}"
# Fallback: coerce unexpected types to str to avoid TypeError
return str(content) + f"\n\n{text}"
def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
warning, hard_stop = self._track_and_check(state, runtime)
@ -192,7 +209,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
stripped_msg = last_msg.model_copy(
update={
"tool_calls": [],
"content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
"content": self._append_text(last_msg.content, _HARD_STOP_MSG),
}
)
return {"messages": [stripped_msg]}

View File

@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = (
re.compile(r"改用"),
)
_REINFORCEMENT_PATTERNS = (
re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
re.compile(r"对[,]?\s*就是这样(?:[。!?!?.]|$)"),
re.compile(r"完全正确(?:[。!?!?.]|$)"),
re.compile(r"(?:对[,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
re.compile(r"继续保持(?:[。!?!?.]|$)"),
)
class MemoryMiddlewareState(AgentState):
"""Compatible with the `ThreadState` schema."""
@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool:
return False
def detect_reinforcement(messages: list[Any]) -> bool:
"""Detect explicit positive reinforcement signals in recent conversation turns.
Complements detect_correction() by identifying when the user confirms the
agent's approach was correct. This allows the memory system to record what
worked well, not just what went wrong.
The queue keeps only one pending context per thread, so callers pass the
latest filtered message list. Checking only recent user turns keeps signal
detection conservative while avoiding stale signals from long histories.
"""
recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]
for msg in recent_user_msgs:
content = _extract_message_text(msg).strip()
if not content:
continue
if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
return True
return False
class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
"""Middleware that queues conversation for memory update after agent execution.
@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
queue = get_memory_queue()
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
return None

View File

@ -101,44 +101,33 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return user_msg if user_msg else "New Conversation"
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None."""
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = model.invoke(prompt)
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
except Exception:
logger.exception("Failed to generate title (sync)")
title = self._fallback_title(user_msg)
return {"title": title}
_, user_msg = self._build_title_prompt(state)
return {"title": self._fallback_title(user_msg)}
async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Asynchronously generate a title. Returns state update or None."""
"""Generate a title asynchronously and fall back locally on failure."""
if not self._should_generate_title(state):
return None
prompt, user_msg = self._build_title_prompt(state)
config = get_title_config()
model = create_chat_model(name=config.model_name, thinking_enabled=False)
prompt, user_msg = self._build_title_prompt(state)
try:
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt)
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
if title:
return {"title": title}
except Exception:
logger.exception("Failed to generate title (async)")
title = self._fallback_title(user_msg)
return {"title": title}
logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
return {"title": self._fallback_title(user_msg)}
@override
def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:

View File

@ -138,6 +138,6 @@ def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentM
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
include_uploads=False,
include_dangling_tool_call_patch=False,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)

View File

@ -10,10 +10,52 @@ from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__)
_OUTLINE_PREVIEW_LINES = 5
def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
"""Return the document outline and fallback preview for *file_path*.
Looks for a sibling ``<stem>.md`` file produced by the upload conversion
pipeline.
Returns:
(outline, preview) where:
- outline: list of ``{title, line}`` dicts (plus optional sentinel).
Empty when no headings are found or no .md exists.
- preview: first few non-empty lines of the .md, used as a content
anchor when outline is empty so the agent has some context.
Empty when outline is non-empty (no fallback needed).
"""
md_path = file_path.with_suffix(".md")
if not md_path.is_file():
return [], []
outline = extract_outline(md_path)
if outline:
logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
return outline, []
# outline is empty — read the first few non-empty lines as a content preview
preview: list[str] = []
try:
with md_path.open(encoding="utf-8") as f:
for line in f:
stripped = line.strip()
if stripped:
preview.append(stripped)
if len(preview) >= _OUTLINE_PREVIEW_LINES:
break
except Exception:
logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
return [], preview
class UploadsMiddlewareState(AgentState):
"""State schema for uploads middleware."""
@ -39,12 +81,38 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
super().__init__()
self._paths = Paths(base_dir) if base_dir else get_paths()
def _format_file_entry(self, file: dict, lines: list[str]) -> None:
"""Append a single file entry (name, size, path, optional outline) to lines."""
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
outline = file.get("outline") or []
if outline:
truncated = outline[-1].get("truncated", False)
visible = [e for e in outline if not e.get("truncated")]
lines.append(" Document outline (use `read_file` with line ranges to read sections):")
for entry in visible:
lines.append(f" L{entry['line']}: {entry['title']}")
if truncated:
lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
else:
preview = file.get("outline_preview") or []
if preview:
lines.append(" No structural headings detected. Document begins with:")
for text in preview:
lines.append(f" > {text}")
lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
lines.append("")
def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
"""Create a formatted message listing uploaded files.
Args:
new_files: Files uploaded in the current message.
historical_files: Files uploaded in previous messages.
Each file dict may contain an optional ``outline`` key: a list of
``{title, line}`` dicts extracted from the converted Markdown file.
Returns:
Formatted string inside <uploaded_files> tags.
@ -55,25 +123,24 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
lines.append("")
if new_files:
for file in new_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
self._format_file_entry(file, lines)
else:
lines.append("(empty)")
lines.append("")
if historical_files:
lines.append("The following files were uploaded in previous messages and are still available:")
lines.append("")
for file in historical_files:
size_kb = file["size"] / 1024
size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
lines.append(f"- {file['filename']} ({size_str})")
lines.append(f" Path: {file['path']}")
lines.append("")
self._format_file_entry(file, lines)
lines.append("You can read these files using the `read_file` tool with the paths shown above.")
lines.append("To work with these files:")
lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
lines.append("- Use `glob` to find files by name pattern")
lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
lines.append("</uploaded_files>")
return "\n".join(lines)
@ -147,6 +214,13 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
# Resolve uploads directory for existence checks
thread_id = (runtime.context or {}).get("thread_id")
if thread_id is None:
try:
from langgraph.config import get_config
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
@ -159,15 +233,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
for file_path in sorted(uploads_dir.iterdir()):
if file_path.is_file() and file_path.name not in new_filenames:
stat = file_path.stat()
outline, preview = _extract_outline_for_file(file_path)
historical_files.append(
{
"filename": file_path.name,
"size": stat.st_size,
"path": f"/mnt/user-data/uploads/{file_path.name}",
"extension": file_path.suffix,
"outline": outline,
"outline_preview": preview,
}
)
# Attach outlines to new files as well
if uploads_dir:
for file in new_files:
phys_path = uploads_dir / file["filename"]
outline, preview = _extract_outline_for_file(phys_path)
file["outline"] = outline
file["outline_preview"] = preview
if not new_files and not historical_files:
return None

View File

@ -117,6 +117,7 @@ class DeerFlowClient:
subagent_enabled: bool = False,
plan_mode: bool = False,
agent_name: str | None = None,
available_skills: set[str] | None = None,
middlewares: Sequence[AgentMiddleware] | None = None,
):
"""Initialize the client.
@ -133,6 +134,7 @@ class DeerFlowClient:
subagent_enabled: Enable subagent delegation.
plan_mode: Enable TodoList middleware for plan mode.
agent_name: Name of the agent to use.
available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
middlewares: Optional list of custom middlewares to inject into the agent.
"""
if config_path is not None:
@ -148,6 +150,7 @@ class DeerFlowClient:
self._subagent_enabled = subagent_enabled
self._plan_mode = plan_mode
self._agent_name = agent_name
self._available_skills = set(available_skills) if available_skills is not None else None
self._middlewares = list(middlewares) if middlewares else []
# Lazy agent — created on first call, recreated when config changes.
@ -208,6 +211,8 @@ class DeerFlowClient:
cfg.get("thinking_enabled"),
cfg.get("is_plan_mode"),
cfg.get("subagent_enabled"),
self._agent_name,
frozenset(self._available_skills) if self._available_skills is not None else None,
)
if self._agent is not None and self._agent_config_key == key:
@ -226,6 +231,7 @@ class DeerFlowClient:
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=self._agent_name,
available_skills=self._available_skills,
),
"state_schema": ThreadState,
}

View File

@ -7,6 +7,7 @@ import uuid
from agent_sandbox import Sandbox as AioSandboxClient
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line
logger = logging.getLogger(__name__)
@ -135,6 +136,86 @@ class AioSandbox(Sandbox):
logger.error(f"Failed to write file in sandbox: {e}")
raise
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
if not include_dirs:
result = self._client.file.find_files(path=path, glob=pattern)
files = result.data.files if result.data and result.data.files else []
filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
truncated = len(filtered) > max_results
return filtered[:max_results], truncated
result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
entries = result.data.files if result.data and result.data.files else []
matches: list[str] = []
root_path = path.rstrip("/") or "/"
root_prefix = root_path if root_path == "/" else f"{root_path}/"
for entry in entries:
if entry.path != root_path and not entry.path.startswith(root_prefix):
continue
if should_ignore_path(entry.path):
continue
rel_path = entry.path[len(root_path) :].lstrip("/")
if path_matches(pattern, rel_path):
matches.append(entry.path)
if len(matches) >= max_results:
return matches, True
return matches, False
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
import re as _re
regex_source = _re.escape(pattern) if literal else pattern
# Validate the pattern locally so an invalid regex raises re.error
# (caught by grep_tool's except re.error handler) rather than a
# generic remote API error.
_re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
regex = regex_source if case_sensitive else f"(?i){regex_source}"
if glob is not None:
find_result = self._client.file.find_files(path=path, glob=glob)
candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
else:
list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
entries = list_result.data.files if list_result.data and list_result.data.files else []
candidate_paths = [entry.path for entry in entries if not entry.is_directory]
matches: list[GrepMatch] = []
truncated = False
for file_path in candidate_paths:
if should_ignore_path(file_path):
continue
search_result = self._client.file.search_in_file(file=file_path, regex=regex)
data = search_result.data
if data is None:
continue
line_numbers = data.line_numbers or []
matched_lines = data.matches or []
for line_number, line in zip(line_numbers, matched_lines):
matches.append(
GrepMatch(
path=file_path,
line_number=line_number if isinstance(line_number, int) else 0,
line=truncate_line(line),
)
)
if len(matches) >= max_results:
truncated = True
return matches, truncated
return matches, truncated
def update_file(self, path: str, content: bytes) -> None:
"""Update a file with binary content in the sandbox.

View File

@ -1,5 +1,6 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
@ -10,15 +11,15 @@ from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import load_guardrails_config_from_dict
from deerflow.config.memory_config import load_memory_config_from_dict
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
from deerflow.config.subagents_config import load_subagents_config_from_dict
from deerflow.config.summarization_config import load_summarization_config_from_dict
from deerflow.config.title_config import load_title_config_from_dict
from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
@ -28,6 +29,13 @@ load_dotenv()
logger = logging.getLogger(__name__)
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
return (backend_dir / "config.yaml", repo_root / "config.yaml")
class AppConfig(BaseModel):
"""Config for the DeerFlow application"""
@ -40,6 +48,11 @@ class AppConfig(BaseModel):
skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
model_config = ConfigDict(extra="allow", frozen=False)
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@ -51,7 +64,7 @@ class AppConfig(BaseModel):
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory.
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
"""
if config_path:
path = Path(config_path)
@ -64,14 +77,10 @@ class AppConfig(BaseModel):
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
return path
else:
# Check if the config.yaml is in the current directory
path = Path(os.getcwd()) / "config.yaml"
if not path.exists():
# Check if the config.yaml is in the parent directory of CWD
path = Path(os.getcwd()).parent / "config.yaml"
if not path.exists():
raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory")
return path
for path in _default_config_candidates():
if path.exists():
return path
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
@classmethod
def from_file(cls, config_path: str | None = None) -> Self:
@ -244,6 +253,8 @@ _app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
@ -276,6 +287,10 @@ def get_app_config() -> AppConfig:
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
@ -337,3 +352,26 @@ def set_app_config(config: AppConfig) -> None:
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
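# Usage sketch (illustrative, not part of this diff): the push/pop pair is meant
# to bracket a single execution scope so get_app_config() resolves the override:
#
#     push_current_app_config(override_config)
#     try:
#         ...  # everything here sees override_config via get_app_config()
#     finally:
#         pop_current_app_config()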

View File

@ -80,6 +80,12 @@ class ExtensionsConfig(BaseModel):
Args:
config_path: Optional path to extensions config file.
Resolution order:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
3. Otherwise, search backend/repository-root defaults for
`extensions_config.json`, then legacy `mcp_config.json`.
Returns:
Path to the extensions config file if found, otherwise None.
"""
@ -94,24 +100,16 @@ class ExtensionsConfig(BaseModel):
raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
return path
else:
# Check if the extensions_config.json is in the current directory
path = Path(os.getcwd()) / "extensions_config.json"
if path.exists():
return path
# Check if the extensions_config.json is in the parent directory of CWD
path = Path(os.getcwd()).parent / "extensions_config.json"
if path.exists():
return path
# Backward compatibility: check for mcp_config.json
path = Path(os.getcwd()) / "mcp_config.json"
if path.exists():
return path
path = Path(os.getcwd()).parent / "mcp_config.json"
if path.exists():
return path
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
for path in (
backend_dir / "extensions_config.json",
repo_root / "extensions_config.json",
backend_dir / "mcp_config.json",
repo_root / "mcp_config.json",
):
if path.exists():
return path
# Extensions are optional, so return None if not found
return None

View File

@ -9,6 +9,12 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
def _default_local_base_dir() -> Path:
"""Return the repo-local DeerFlow state directory without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
return backend_dir / ".deer-flow"
def _validate_thread_id(thread_id: str) -> str:
"""Validate a thread ID before using it in filesystem paths."""
if not _SAFE_THREAD_ID_RE.match(thread_id):
@ -67,8 +73,7 @@ class Paths:
BaseDir resolution (in priority order):
1. Constructor argument `base_dir`
2. DEER_FLOW_HOME environment variable
3. Local dev fallback: cwd/.deer-flow (when cwd is the backend/ dir)
4. Default: $HOME/.deer-flow
3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
"""
def __init__(self, base_dir: str | Path | None = None) -> None:
@ -104,11 +109,7 @@ class Paths:
if env_home := os.getenv("DEER_FLOW_HOME"):
return Path(env_home).resolve()
cwd = Path.cwd()
if cwd.name == "backend" or (cwd / "pyproject.toml").exists():
return cwd / ".deer-flow"
return Path.home() / ".deer-flow"
return _default_local_base_dir()
@property
def memory_file(self) -> Path:

View File

@ -3,6 +3,11 @@ from pathlib import Path
from pydantic import BaseModel, Field
def _default_repo_root() -> Path:
"""Resolve the repo root without relying on the current working directory."""
return Path(__file__).resolve().parents[5]
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
@ -26,8 +31,8 @@ class SkillsConfig(BaseModel):
# Use configured path (can be absolute or relative)
path = Path(self.path)
if not path.is_absolute():
# If relative, resolve from current working directory
path = Path.cwd() / path
# If relative, resolve from the repo root for deterministic behavior.
path = _default_repo_root() / path
return path.resolve()
else:
# Default: ../skills relative to backend directory

View File

@ -15,6 +15,11 @@ class SubagentOverrideConfig(BaseModel):
ge=1,
description="Timeout in seconds for this subagent (None = use global default)",
)
max_turns: int | None = Field(
default=None,
ge=1,
description="Maximum turns for this subagent (None = use global or builtin default)",
)
class SubagentsAppConfig(BaseModel):
@ -25,6 +30,11 @@ class SubagentsAppConfig(BaseModel):
ge=1,
description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
)
max_turns: int | None = Field(
default=None,
ge=1,
description="Optional default max-turn override for all subagents (None = keep builtin defaults)",
)
agents: dict[str, SubagentOverrideConfig] = Field(
default_factory=dict,
description="Per-agent configuration overrides keyed by agent name",
@ -44,6 +54,15 @@ class SubagentsAppConfig(BaseModel):
return override.timeout_seconds
return self.timeout_seconds
def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
"""Get the effective max_turns for a specific agent."""
override = self.agents.get(agent_name)
if override is not None and override.max_turns is not None:
return override.max_turns
if self.max_turns is not None:
return self.max_turns
return builtin_default
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
@ -58,8 +77,26 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
global _subagents_config
_subagents_config = SubagentsAppConfig(**config_dict)
overrides_summary = {name: f"{override.timeout_seconds}s" for name, override in _subagents_config.agents.items() if override.timeout_seconds is not None}
overrides_summary = {}
for name, override in _subagents_config.agents.items():
parts = []
if override.timeout_seconds is not None:
parts.append(f"timeout={override.timeout_seconds}s")
if override.max_turns is not None:
parts.append(f"max_turns={override.max_turns}")
if parts:
overrides_summary[name] = ", ".join(parts)
if overrides_summary:
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, per-agent overrides={overrides_summary}")
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary,
)
else:
logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, no per-agent overrides")
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
)

View File

@ -25,6 +25,7 @@ class MemoryStreamBridge(StreamBridge):
self._maxsize = queue_maxsize
self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
self._counters: dict[str, int] = {}
self._dropped_counts: dict[str, int] = {}
# -- helpers ---------------------------------------------------------------
@ -32,6 +33,7 @@ class MemoryStreamBridge(StreamBridge):
if run_id not in self._queues:
self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
self._counters[run_id] = 0
self._dropped_counts[run_id] = 0
return self._queues[run_id]
def _next_id(self, run_id: str) -> str:
@ -48,14 +50,41 @@ class MemoryStreamBridge(StreamBridge):
try:
await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
except TimeoutError:
logger.warning("Stream bridge queue full for run %s — dropping event %s", run_id, event)
self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
logger.warning(
"Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
run_id,
event,
self._dropped_counts[run_id],
)
async def publish_end(self, run_id: str) -> None:
queue = self._get_or_create_queue(run_id)
try:
await asyncio.wait_for(queue.put(END_SENTINEL), timeout=_PUBLISH_TIMEOUT)
except TimeoutError:
logger.warning("Stream bridge queue full for run %s — dropping END sentinel", run_id)
# END sentinel is critical — it is the only signal that allows
# subscribers to terminate. If the queue is full we evict the
# oldest *regular* events to make room rather than dropping END,
# which would cause the SSE connection to hang forever and leak
# the queue/counter resources for this run_id.
if queue.full():
evicted = 0
while queue.full():
try:
queue.get_nowait()
evicted += 1
except asyncio.QueueEmpty:
break  # pragma: no cover (defensive)
if evicted:
logger.warning(
"Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
run_id,
evicted,
)
# After eviction the queue is guaranteed to have space, so a
# simple non-blocking put is safe. We still use put() (which
# blocks until space is available) as a defensive measure.
await queue.put(END_SENTINEL)
async def subscribe(
self,
@ -84,7 +113,18 @@ class MemoryStreamBridge(StreamBridge):
await asyncio.sleep(delay)
self._queues.pop(run_id, None)
self._counters.pop(run_id, None)
self._dropped_counts.pop(run_id, None)
async def close(self) -> None:
self._queues.clear()
self._counters.clear()
self._dropped_counts.clear()
def dropped_count(self, run_id: str) -> int:
"""Return the number of events dropped for *run_id*."""
return self._dropped_counts.get(run_id, 0)
@property
def dropped_total(self) -> int:
"""Return the total number of events dropped across all runs."""
return sum(self._dropped_counts.values())

View File

@ -1,72 +1,6 @@
import fnmatch
from pathlib import Path
IGNORE_PATTERNS = [
# Version Control
".git",
".svn",
".hg",
".bzr",
# Dependencies
"node_modules",
"__pycache__",
".venv",
"venv",
".env",
"env",
".tox",
".nox",
".eggs",
"*.egg-info",
"site-packages",
# Build outputs
"dist",
"build",
".next",
".nuxt",
".output",
".turbo",
"target",
"out",
# IDE & Editor
".idea",
".vscode",
"*.swp",
"*.swo",
"*~",
".project",
".classpath",
".settings",
# OS generated
".DS_Store",
"Thumbs.db",
"desktop.ini",
"*.lnk",
# Logs & temp files
"*.log",
"*.tmp",
"*.temp",
"*.bak",
"*.cache",
".cache",
"logs",
# Coverage & test artifacts
".coverage",
"coverage",
".nyc_output",
"htmlcov",
".pytest_cache",
".mypy_cache",
".ruff_cache",
]
def _should_ignore(name: str) -> bool:
"""Check if a file/directory name matches any ignore pattern."""
for pattern in IGNORE_PATTERNS:
if fnmatch.fnmatch(name, pattern):
return True
return False
from deerflow.sandbox.search import should_ignore_name
def list_dir(path: str, max_depth: int = 2) -> list[str]:
@ -95,7 +29,7 @@ def list_dir(path: str, max_depth: int = 2) -> list[str]:
try:
for item in current_path.iterdir():
if _should_ignore(item.name):
if should_ignore_name(item.name):
continue
post_fix = "/" if item.is_dir() else ""

View File

@ -1,11 +1,23 @@
import errno
import ntpath
import os
import shutil
import subprocess
from dataclasses import dataclass
from pathlib import Path
from deerflow.sandbox.local.list_dir import list_dir
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
@dataclass(frozen=True)
class PathMapping:
"""A path mapping from a container path to a local path with optional read-only flag."""
container_path: str
local_path: str
read_only: bool = False
class LocalSandbox(Sandbox):
@ -39,17 +51,42 @@ class LocalSandbox(Sandbox):
return None
def __init__(self, id: str, path_mappings: dict[str, str] | None = None):
def __init__(self, id: str, path_mappings: list[PathMapping] | None = None):
"""
Initialize local sandbox with optional path mappings.
Args:
id: Sandbox identifier
path_mappings: Dictionary mapping container paths to local paths
Example: {"/mnt/skills": "/absolute/path/to/skills"}
path_mappings: List of path mappings with optional read-only flag.
Skills directory is read-only by default.
"""
super().__init__(id)
self.path_mappings = path_mappings or {}
self.path_mappings = path_mappings or []
def _is_read_only_path(self, resolved_path: str) -> bool:
"""Check if a resolved path is under a read-only mount.
When multiple mappings match (nested mounts), prefer the most specific
mapping (i.e. the one whose local_path is the longest prefix of the
resolved path), similar to how ``_resolve_path`` handles container paths.
"""
resolved = str(Path(resolved_path).resolve())
best_mapping: PathMapping | None = None
best_prefix_len = -1
for mapping in self.path_mappings:
local_resolved = str(Path(mapping.local_path).resolve())
if resolved == local_resolved or resolved.startswith(local_resolved + os.sep):
prefix_len = len(local_resolved)
if prefix_len > best_prefix_len:
best_prefix_len = prefix_len
best_mapping = mapping
if best_mapping is None:
return False
return best_mapping.read_only
def _resolve_path(self, path: str) -> str:
"""
@ -64,7 +101,9 @@ class LocalSandbox(Sandbox):
path_str = str(path)
# Try each mapping (longest prefix first for more specific matches)
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True):
for mapping in sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True):
container_path = mapping.container_path
local_path = mapping.local_path
if path_str == container_path or path_str.startswith(container_path + "/"):
# Replace the container path prefix with local path
relative = path_str[len(container_path) :].lstrip("/")
@ -84,15 +123,16 @@ class LocalSandbox(Sandbox):
Returns:
Container path if mapping exists, otherwise original path
"""
path_str = str(Path(path).resolve())
normalized_path = path.replace("\\", "/")
path_str = str(Path(normalized_path).resolve())
# Try each mapping (longest local path first for more specific matches)
for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True):
local_path_resolved = str(Path(local_path).resolve())
if path_str.startswith(local_path_resolved):
for mapping in sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True):
local_path_resolved = str(Path(mapping.local_path).resolve())
if path_str == local_path_resolved or path_str.startswith(local_path_resolved + "/"):
# Replace the local path prefix with container path
relative = path_str[len(local_path_resolved) :].lstrip("/")
resolved = f"{container_path}/{relative}" if relative else container_path
resolved = f"{mapping.container_path}/{relative}" if relative else mapping.container_path
return resolved
# No mapping found, return original path
@ -111,7 +151,7 @@ class LocalSandbox(Sandbox):
import re
# Sort mappings by local path length (longest first) for correct prefix matching
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True)
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True)
if not sorted_mappings:
return output
@ -119,12 +159,11 @@ class LocalSandbox(Sandbox):
# Create pattern that matches absolute paths
# Match paths like /Users/... or other absolute paths
result = output
for container_path, local_path in sorted_mappings:
local_path_resolved = str(Path(local_path).resolve())
for mapping in sorted_mappings:
# Escape the local path for use in regex
escaped_local = re.escape(local_path_resolved)
# Match the local path followed by optional path components
pattern = re.compile(escaped_local + r"(?:/[^\s\"';&|<>()]*)?")
escaped_local = re.escape(str(Path(mapping.local_path).resolve()))
# Match the local path followed by optional path components with either separator
pattern = re.compile(escaped_local + r"(?:[/\\][^\s\"';&|<>()]*)?")
def replace_match(match: re.Match) -> str:
matched_path = match.group(0)
@ -147,7 +186,7 @@ class LocalSandbox(Sandbox):
import re
# Sort mappings by length (longest first) for correct prefix matching
sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True)
sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True)
# Build regex pattern to match all container paths
# Match container path followed by optional path components
@ -157,7 +196,7 @@ class LocalSandbox(Sandbox):
# Create pattern that matches any of the container paths.
# The lookahead (?=/|$|...) ensures we only match at a path-segment boundary,
# preventing /mnt/skills from matching inside /mnt/skills-extra.
patterns = [re.escape(container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for container_path, _ in sorted_mappings]
patterns = [re.escape(m.container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for m in sorted_mappings]
pattern = re.compile("|".join(f"({p})" for p in patterns))
def replace_match(match: re.Match) -> str:
@ -248,6 +287,8 @@ class LocalSandbox(Sandbox):
def write_file(self, path: str, content: str, append: bool = False) -> None:
resolved_path = self._resolve_path(path)
if self._is_read_only_path(resolved_path):
raise OSError(errno.EROFS, "Read-only file system", path)
try:
dir_path = os.path.dirname(resolved_path)
if dir_path:
@ -259,8 +300,43 @@ class LocalSandbox(Sandbox):
# Re-raise with the original path for clearer error messages, hiding internal resolved paths
raise type(e)(e.errno, e.strerror, path) from None
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
resolved_path = Path(self._resolve_path(path))
matches, truncated = find_glob_matches(resolved_path, pattern, include_dirs=include_dirs, max_results=max_results)
return [self._reverse_resolve_path(match) for match in matches], truncated
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
resolved_path = Path(self._resolve_path(path))
matches, truncated = find_grep_matches(
resolved_path,
pattern,
glob_pattern=glob,
literal=literal,
case_sensitive=case_sensitive,
max_results=max_results,
)
return [
GrepMatch(
path=self._reverse_resolve_path(match.path),
line_number=match.line_number,
line=match.line,
)
for match in matches
], truncated
def update_file(self, path: str, content: bytes) -> None:
resolved_path = self._resolve_path(path)
if self._is_read_only_path(resolved_path):
raise OSError(errno.EROFS, "Read-only file system", path)
try:
dir_path = os.path.dirname(resolved_path)
if dir_path:

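# Hedged usage sketch for the read-only mount enforcement above; the directory
# names are illustrative and created under a temp dir so the snippet is
# self-contained.
import tempfile
from pathlib import Path

from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping

root = Path(tempfile.mkdtemp())
(root / "skills").mkdir()
(root / "data").mkdir()
sandbox = LocalSandbox(
    id="demo",
    path_mappings=[
        PathMapping(container_path="/mnt/skills", local_path=str(root / "skills"), read_only=True),
        PathMapping(container_path="/mnt/data", local_path=str(root / "data")),
    ],
)
sandbox.write_file("/mnt/data/note.txt", "ok")  # writable mount succeeds
try:
    sandbox.write_file("/mnt/skills/readme.md", "x")
except OSError as err:
    print(err)  # EROFS, reported with the virtual path, not the resolved one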
View File

@ -1,6 +1,7 @@
import logging
from pathlib import Path
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
@ -14,16 +15,17 @@ class LocalSandboxProvider(SandboxProvider):
"""Initialize the local sandbox provider with path mappings."""
self._path_mappings = self._setup_path_mappings()
def _setup_path_mappings(self) -> dict[str, str]:
def _setup_path_mappings(self) -> list[PathMapping]:
"""
Setup path mappings for local sandbox.
Maps container paths to actual local paths, including skills directory.
Maps container paths to actual local paths, including skills directory
and any custom mounts configured in config.yaml.
Returns:
Dictionary of path mappings
List of path mappings
"""
mappings = {}
mappings: list[PathMapping] = []
# Map skills container path to local skills directory
try:
@ -35,10 +37,63 @@ class LocalSandboxProvider(SandboxProvider):
# Only add mapping if skills directory exists
if skills_path.exists():
mappings[container_path] = str(skills_path)
mappings.append(
PathMapping(
container_path=container_path,
local_path=str(skills_path),
read_only=True, # Skills directory is always read-only
)
)
# Map custom mounts from sandbox config
_RESERVED_CONTAINER_PREFIXES = [container_path, "/mnt/acp-workspace", "/mnt/user-data"]
sandbox_config = config.sandbox
if sandbox_config and sandbox_config.mounts:
for mount in sandbox_config.mounts:
host_path = Path(mount.host_path)
container_path = mount.container_path.rstrip("/") or "/"
if not host_path.is_absolute():
logger.warning(
"Mount host_path must be absolute, skipping: %s -> %s",
mount.host_path,
mount.container_path,
)
continue
if not container_path.startswith("/"):
logger.warning(
"Mount container_path must be absolute, skipping: %s -> %s",
mount.host_path,
mount.container_path,
)
continue
# Reject mounts that conflict with reserved container paths
if any(container_path == p or container_path.startswith(p + "/") for p in _RESERVED_CONTAINER_PREFIXES):
logger.warning(
"Mount container_path conflicts with reserved prefix, skipping: %s",
mount.container_path,
)
continue
# Ensure the host path exists before adding mapping
if host_path.exists():
mappings.append(
PathMapping(
container_path=container_path,
local_path=str(host_path.resolve()),
read_only=mount.read_only,
)
)
else:
logger.warning(
"Mount host_path does not exist, skipping: %s -> %s",
mount.host_path,
mount.container_path,
)
except Exception as e:
# Log but don't fail if config loading fails
logger.warning("Could not setup skills path mapping: %s", e, exc_info=True)
logger.warning("Could not setup path mappings: %s", e, exc_info=True)
return mappings
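# Hedged sketch of how _setup_path_mappings() treats config.yaml entries. The
# exact YAML keys are an assumption inferred from mount.host_path,
# mount.container_path and mount.read_only used above:
#
#   sandbox:
#     mounts:
#       - host_path: /data/reports          # absolute + exists -> mounted
#         container_path: /mnt/reports
#         read_only: true
#       - host_path: reports                # not absolute      -> skipped, warning
#         container_path: /mnt/rel
#       - host_path: /data/other
#         container_path: /mnt/user-data/x  # reserved prefix   -> skipped, warning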

View File

@ -1,5 +1,7 @@
from abc import ABC, abstractmethod
from deerflow.sandbox.search import GrepMatch
class Sandbox(ABC):
"""Abstract base class for sandbox environments"""
@ -61,6 +63,25 @@ class Sandbox(ABC):
"""
pass
@abstractmethod
def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
"""Find paths that match a glob pattern under a root directory."""
pass
@abstractmethod
def grep(
self,
path: str,
pattern: str,
*,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
) -> tuple[list[GrepMatch], bool]:
"""Search for matches inside text files under a directory."""
pass
@abstractmethod
def update_file(self, path: str, content: bytes) -> None:
"""Update a file with binary content.

View File

@ -0,0 +1,210 @@
import fnmatch
import os
import re
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
IGNORE_PATTERNS = [
".git",
".svn",
".hg",
".bzr",
"node_modules",
"__pycache__",
".venv",
"venv",
".env",
"env",
".tox",
".nox",
".eggs",
"*.egg-info",
"site-packages",
"dist",
"build",
".next",
".nuxt",
".output",
".turbo",
"target",
"out",
".idea",
".vscode",
"*.swp",
"*.swo",
"*~",
".project",
".classpath",
".settings",
".DS_Store",
"Thumbs.db",
"desktop.ini",
"*.lnk",
"*.log",
"*.tmp",
"*.temp",
"*.bak",
"*.cache",
".cache",
"logs",
".coverage",
"coverage",
".nyc_output",
"htmlcov",
".pytest_cache",
".mypy_cache",
".ruff_cache",
]
DEFAULT_MAX_FILE_SIZE_BYTES = 1_000_000
DEFAULT_LINE_SUMMARY_LENGTH = 200
@dataclass(frozen=True)
class GrepMatch:
path: str
line_number: int
line: str
def should_ignore_name(name: str) -> bool:
for pattern in IGNORE_PATTERNS:
if fnmatch.fnmatch(name, pattern):
return True
return False
def should_ignore_path(path: str) -> bool:
return any(should_ignore_name(segment) for segment in path.replace("\\", "/").split("/") if segment)
def path_matches(pattern: str, rel_path: str) -> bool:
path = PurePosixPath(rel_path)
if path.match(pattern):
return True
if pattern.startswith("**/"):
return path.match(pattern[3:])
return False
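# Matching semantics of path_matches() (PurePosixPath.match is right-anchored;
# the "**/" fallback lets top-level files match recursive patterns):
#   path_matches("**/*.py", "pkg/mod/util.py")  -> True
#   path_matches("**/*.py", "setup.py")         -> True  (fallback to "*.py")
#   path_matches("src/*.py", "src/app.py")      -> True
#   path_matches("src/*.py", "other/app.py")    -> False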
def truncate_line(line: str, max_chars: int = DEFAULT_LINE_SUMMARY_LENGTH) -> str:
line = line.rstrip("\n\r")
if len(line) <= max_chars:
return line
return line[: max_chars - 3] + "..."
def is_binary_file(path: Path, sample_size: int = 8192) -> bool:
try:
with path.open("rb") as handle:
return b"\0" in handle.read(sample_size)
except OSError:
return True
def find_glob_matches(root: Path, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
matches: list[str] = []
truncated = False
root = root.resolve()
if not root.exists():
raise FileNotFoundError(root)
if not root.is_dir():
raise NotADirectoryError(root)
for current_root, dirs, files in os.walk(root):
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
# root is already resolved; os.walk builds current_root by joining under root,
# so relative_to() works without an extra stat()/resolve() per directory.
rel_dir = Path(current_root).relative_to(root)
if include_dirs:
for name in dirs:
rel_path = (rel_dir / name).as_posix()
if path_matches(pattern, rel_path):
matches.append(str(Path(current_root) / name))
if len(matches) >= max_results:
truncated = True
return matches, truncated
for name in files:
if should_ignore_name(name):
continue
rel_path = (rel_dir / name).as_posix()
if path_matches(pattern, rel_path):
matches.append(str(Path(current_root) / name))
if len(matches) >= max_results:
truncated = True
return matches, truncated
return matches, truncated
def find_grep_matches(
root: Path,
pattern: str,
*,
glob_pattern: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = 100,
max_file_size: int = DEFAULT_MAX_FILE_SIZE_BYTES,
line_summary_length: int = DEFAULT_LINE_SUMMARY_LENGTH,
) -> tuple[list[GrepMatch], bool]:
matches: list[GrepMatch] = []
truncated = False
root = root.resolve()
if not root.exists():
raise FileNotFoundError(root)
if not root.is_dir():
raise NotADirectoryError(root)
regex_source = re.escape(pattern) if literal else pattern
flags = 0 if case_sensitive else re.IGNORECASE
regex = re.compile(regex_source, flags)
# Skip lines longer than this to prevent ReDoS on minified / no-newline files.
_max_line_chars = line_summary_length * 10
for current_root, dirs, files in os.walk(root):
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
rel_dir = Path(current_root).relative_to(root)
for name in files:
if should_ignore_name(name):
continue
candidate_path = Path(current_root) / name
rel_path = (rel_dir / name).as_posix()
if glob_pattern is not None and not path_matches(glob_pattern, rel_path):
continue
try:
if candidate_path.is_symlink():
continue
file_path = candidate_path.resolve()
if not file_path.is_relative_to(root):
continue
if file_path.stat().st_size > max_file_size or is_binary_file(file_path):
continue
with file_path.open(encoding="utf-8", errors="replace") as handle:
for line_number, line in enumerate(handle, start=1):
if len(line) > _max_line_chars:
continue
if regex.search(line):
matches.append(
GrepMatch(
path=str(file_path),
line_number=line_number,
line=truncate_line(line, line_summary_length),
)
)
if len(matches) >= max_results:
truncated = True
return matches, truncated
except OSError:
continue
return matches, truncated
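# Quick usage sketch for the two helpers above (run from any project root):
from pathlib import Path

from deerflow.sandbox.search import find_glob_matches, find_grep_matches

paths, truncated = find_glob_matches(Path("."), "**/*.py", max_results=10)
print(f"{len(paths)} python files{' (truncated)' if truncated else ''}")

matches, truncated = find_grep_matches(Path("."), "TODO", literal=True, max_results=5)
for match in matches:
    print(f"{match.path}:{match.line_number}: {match.line}")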

View File

@ -7,6 +7,7 @@ from langchain.tools import ToolRuntime, tool
from langgraph.typing import ContextT
from deerflow.agents.thread_state import ThreadDataState, ThreadState
from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
from deerflow.sandbox.exceptions import (
SandboxError,
@ -16,6 +17,7 @@ from deerflow.sandbox.exceptions import (
from deerflow.sandbox.file_operation_lock import get_file_operation_lock
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.sandbox.search import GrepMatch
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
@ -31,6 +33,10 @@ _LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
_DEFAULT_SKILLS_CONTAINER_PATH = "/mnt/skills"
_ACP_WORKSPACE_VIRTUAL_PATH = "/mnt/acp-workspace"
_DEFAULT_GLOB_MAX_RESULTS = 200
_MAX_GLOB_MAX_RESULTS = 1000
_DEFAULT_GREP_MAX_RESULTS = 100
_MAX_GREP_MAX_RESULTS = 500
def _get_skills_container_path() -> str:
@ -113,6 +119,54 @@ def _is_acp_workspace_path(path: str) -> bool:
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
def _get_custom_mounts():
"""Get custom volume mounts from sandbox config.
Result is cached after the first successful config load. If config loading
fails, an empty list is returned *without* caching so that a later call can
pick up the real value once the config is available.
"""
cached = getattr(_get_custom_mounts, "_cached", None)
if cached is not None:
return cached
try:
from pathlib import Path
from deerflow.config import get_app_config
config = get_app_config()
mounts = []
if config.sandbox and config.sandbox.mounts:
# Only include mounts whose host_path exists, consistent with
# LocalSandboxProvider._setup_path_mappings() which also filters
# by host_path.exists().
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
return mounts
except Exception:
# If config loading fails, return an empty list without caching so that
# a later call can retry once the config is available.
return []
def _is_custom_mount_path(path: str) -> bool:
"""Check if path is under a custom mount container_path."""
for mount in _get_custom_mounts():
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
return True
return False
def _get_custom_mount_for_path(path: str):
"""Get the mount config matching this path (longest prefix first)."""
best = None
for mount in _get_custom_mounts():
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
if best is None or len(mount.container_path) > len(best.container_path):
best = mount
return best
def _extract_thread_id_from_thread_data(thread_data: "ThreadDataState | None") -> str | None:
"""Extract thread_id from thread_data by inspecting workspace_path.
@ -245,16 +299,84 @@ def _get_mcp_allowed_paths() -> list[str]:
return allowed_paths
def _get_tool_config_int(name: str, key: str, default: int) -> int:
try:
tool_config = get_app_config().get_tool_config(name)
if tool_config is not None and key in tool_config.model_extra:
value = tool_config.model_extra.get(key)
if isinstance(value, int):
return value
except Exception:
pass
return default
def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
if value <= 0:
return default
return min(value, upper_bound)
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
configured_max_results = _clamp_max_results(
_get_tool_config_int(name, "max_results", default),
default=default,
upper_bound=upper_bound,
)
return min(requested_max_results, configured_max_results)
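# Worked example of the clamping above (numbers are hypothetical): a caller
# requests max_results=5000 for grep while config.yaml sets max_results=300.
#   _clamp_max_results(5000, default=100, upper_bound=500) -> 500
#   _clamp_max_results(300,  default=100, upper_bound=500) -> 300
#   effective = min(500, 300)                              -> 300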
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
validate_local_tool_path(path, thread_data, read_only=True)
if _is_skills_path(path):
return _resolve_skills_path(path)
if _is_acp_workspace_path(path):
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
return _resolve_and_validate_user_data_path(path, thread_data)
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
if not matches:
return f"No files matched under {root_path}"
lines = [f"Found {len(matches)} paths under {root_path}"]
if truncated:
lines[0] += f" (showing first {len(matches)})"
lines.extend(f"{index}. {path}" for index, path in enumerate(matches, start=1))
if truncated:
lines.append("Results truncated. Narrow the path or pattern to see fewer matches.")
return "\n".join(lines)
def _format_grep_results(root_path: str, matches: list[GrepMatch], truncated: bool) -> str:
if not matches:
return f"No matches found under {root_path}"
lines = [f"Found {len(matches)} matches under {root_path}"]
if truncated:
lines[0] += f" (showing first {len(matches)})"
lines.extend(f"{match.path}:{match.line_number}: {match.line}" for match in matches)
if truncated:
lines.append("Results truncated. Narrow the path or add a glob filter.")
return "\n".join(lines)
def _path_variants(path: str) -> set[str]:
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
def _path_separator_for_style(path: str) -> str:
return "\\" if "\\" in path and "/" not in path else "/"
def _join_path_preserving_style(base: str, relative: str) -> str:
if not relative:
return base
if "/" in base and "\\" not in base:
return f"{base.rstrip('/')}/{relative}"
return str(Path(base) / relative)
separator = _path_separator_for_style(base)
normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
stripped_base = base.rstrip("/\\")
return f"{stripped_base}{separator}{normalized_relative}"
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
@ -299,7 +421,10 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
return actual_base
if path.startswith(f"{virtual_base}/"):
rest = path[len(virtual_base) :].lstrip("/")
return _join_path_preserving_style(actual_base, rest)
result = _join_path_preserving_style(actual_base, rest)
if path.endswith("/") and not result.endswith(("/", "\\")):
result += _path_separator_for_style(actual_base)
return result
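# e.g. replace_virtual_path("/mnt/user-data/outputs/", thread_data) now keeps
# the trailing separator ("<workspace>/outputs/"), so commands that depend on
# a dir-style path (cp, rsync) behave the same after substitution.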
return path
@ -379,6 +504,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
result = pattern.sub(replace_acp, result)
# Custom mount host paths are masked by LocalSandbox._reverse_resolve_paths_in_output()
# Mask user-data host paths
if thread_data is None:
return result
@ -427,6 +554,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
- ``/mnt/user-data/*`` always allowed (read + write)
- ``/mnt/skills/*`` allowed only when *read_only* is True
- ``/mnt/acp-workspace/*`` allowed only when *read_only* is True
- Custom mount paths (from config.yaml) respect the per-mount ``read_only`` flag
Args:
path: The virtual path to validate.
@ -458,7 +586,14 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
if path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
return
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, or {_ACP_WORKSPACE_VIRTUAL_PATH}/ are allowed")
# Custom mount paths — respect read_only config
if _is_custom_mount_path(path):
mount = _get_custom_mount_for_path(path)
if mount and mount.read_only and not read_only:
raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
return
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
@ -508,9 +643,10 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
boundary and must not be treated as isolation from the host filesystem.
In local mode, commands must use virtual paths under /mnt/user-data for
user data access. Skills paths under /mnt/skills and ACP workspace paths
under /mnt/acp-workspace are allowed (path-traversal checks only; write
prevention for bash commands is not enforced here).
user data access. Skills paths under /mnt/skills, ACP workspace paths
under /mnt/acp-workspace, and custom mount container paths (configured in
config.yaml) are allowed (path-traversal checks only; write prevention
for bash commands is not enforced here).
A small allowlist of common system path prefixes is kept for executable
and device references (e.g. /bin/sh, /dev/null).
"""
@ -545,6 +681,11 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
_reject_path_traversal(absolute_path)
continue
# Allow custom mount container paths
if _is_custom_mount_path(absolute_path):
_reject_path_traversal(absolute_path)
continue
if any(absolute_path == prefix.rstrip("/") or absolute_path.startswith(prefix) for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES):
continue
@ -589,6 +730,8 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
result = acp_pattern.sub(replace_acp_match, result)
# Custom mount paths are resolved by LocalSandbox._resolve_paths_in_command()
# Replace user-data paths
if VIRTUAL_PATH_PREFIX in result and thread_data is not None:
pattern = re.compile(rf"{re.escape(VIRTUAL_PATH_PREFIX)}(/[^\s\"';&|<>()]*)?")
@ -666,7 +809,8 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
if sandbox is None:
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
return sandbox
@ -701,7 +845,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox_id is not None:
sandbox = get_sandbox_provider().get(sandbox_id)
if sandbox is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
# Sandbox was released, fall through to acquire new one
@ -723,7 +868,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
if sandbox is None:
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
if runtime.context is not None:
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
return sandbox
@ -885,8 +1031,9 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
path = _resolve_skills_path(path)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
else:
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
children = sandbox.list_dir(path)
if not children:
return "(empty)"
@ -901,6 +1048,126 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
@tool("glob", parse_docstring=True)
def glob_tool(
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
include_dirs: bool = False,
max_results: int = _DEFAULT_GLOB_MAX_RESULTS,
) -> str:
"""Find files or directories that match a glob pattern under a root directory.
Args:
description: Explain why you are searching for these paths in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
pattern: The glob pattern to match relative to the root path, for example `**/*.py`.
path: The **absolute** root directory to search under.
include_dirs: Whether matching directories should also be returned. Default is False.
max_results: Maximum number of paths to return. Default is 200.
"""
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
"glob",
max_results,
default=_DEFAULT_GLOB_MAX_RESULTS,
upper_bound=_MAX_GLOB_MAX_RESULTS,
)
thread_data = None
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
if thread_data is not None:
matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
return _format_glob_results(requested_path, matches, truncated)
except SandboxError as e:
return f"Error: {e}"
except FileNotFoundError:
return f"Error: Directory not found: {requested_path}"
except NotADirectoryError:
return f"Error: Path is not a directory: {requested_path}"
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
@tool("grep", parse_docstring=True)
def grep_tool(
runtime: ToolRuntime[ContextT, ThreadState],
description: str,
pattern: str,
path: str,
glob: str | None = None,
literal: bool = False,
case_sensitive: bool = False,
max_results: int = _DEFAULT_GREP_MAX_RESULTS,
) -> str:
"""Search for matching lines inside text files under a root directory.
Args:
description: Explain why you are searching file contents in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
pattern: The string or regex pattern to search for.
path: The **absolute** root directory to search under.
glob: Optional glob filter for candidate files, for example `**/*.py`.
literal: Whether to treat `pattern` as a plain string. Default is False.
case_sensitive: Whether matching is case-sensitive. Default is False.
max_results: Maximum number of matching lines to return. Default is 100.
"""
try:
sandbox = ensure_sandbox_initialized(runtime)
ensure_thread_directories_exist(runtime)
requested_path = path
effective_max_results = _resolve_max_results(
"grep",
max_results,
default=_DEFAULT_GREP_MAX_RESULTS,
upper_bound=_MAX_GREP_MAX_RESULTS,
)
thread_data = None
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
if thread_data is None:
raise SandboxRuntimeError("Thread data not available for local sandbox")
path = _resolve_local_read_path(path, thread_data)
matches, truncated = sandbox.grep(
path,
pattern,
glob=glob,
literal=literal,
case_sensitive=case_sensitive,
max_results=effective_max_results,
)
if thread_data is not None:
matches = [
GrepMatch(
path=mask_local_paths_in_output(match.path, thread_data),
line_number=match.line_number,
line=match.line,
)
for match in matches
]
return _format_grep_results(requested_path, matches, truncated)
except SandboxError as e:
return f"Error: {e}"
except FileNotFoundError:
return f"Error: Directory not found: {requested_path}"
except NotADirectoryError:
return f"Error: Path is not a directory: {requested_path}"
except re.error as e:
return f"Error: Invalid regex pattern: {e}"
except PermissionError:
return f"Error: Permission denied: {requested_path}"
except Exception as e:
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
@tool("read_file", parse_docstring=True)
def read_file_tool(
runtime: ToolRuntime[ContextT, ThreadState],
@ -928,8 +1195,9 @@ def read_file_tool(
path = _resolve_skills_path(path)
elif _is_acp_workspace_path(path):
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
else:
elif not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
content = sandbox.read_file(path)
if not content:
return "(empty)"
@ -977,7 +1245,9 @@ def write_file_tool(
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
path = _resolve_and_validate_user_data_path(path, thread_data)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
sandbox.write_file(path, content, append)
return "OK"
@ -1019,7 +1289,9 @@ def str_replace_tool(
if is_local_sandbox(runtime):
thread_data = get_thread_data(runtime)
validate_local_tool_path(path, thread_data)
path = _resolve_and_validate_user_data_path(path, thread_data)
if not _is_custom_mount_path(path):
path = _resolve_and_validate_user_data_path(path, thread_data)
# Custom mount paths are resolved by LocalSandbox._resolve_path()
with get_file_operation_lock(sandbox, path):
content = sandbox.read_file(path)
if not content:

View File

@ -43,5 +43,5 @@ You have access to the sandbox environment:
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
disallowed_tools=["task", "ask_clarification", "present_files"],
model="inherit",
max_turns=30,
max_turns=60,
)

View File

@ -44,5 +44,5 @@ You have access to the same sandbox environment as the parent agent:
tools=None, # Inherit all tools from parent
disallowed_tools=["task", "ask_clarification", "present_files"], # Prevent nesting and clarification
model="inherit",
max_turns=50,
max_turns=100,
)

View File

@ -28,9 +28,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
app_config = get_subagents_app_config()
effective_timeout = app_config.get_timeout_for(name)
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
overrides = {}
if effective_timeout != config.timeout_seconds:
logger.debug(f"Subagent '{name}': timeout overridden by config.yaml ({config.timeout_seconds}s -> {effective_timeout}s)")
config = replace(config, timeout_seconds=effective_timeout)
logger.debug(
"Subagent '%s': timeout overridden by config.yaml (%ss -> %ss)",
name,
config.timeout_seconds,
effective_timeout,
)
overrides["timeout_seconds"] = effective_timeout
if effective_max_turns != config.max_turns:
logger.debug(
"Subagent '%s': max_turns overridden by config.yaml (%s -> %s)",
name,
config.max_turns,
effective_max_turns,
)
overrides["max_turns"] = effective_max_turns
if overrides:
config = replace(config, **overrides)
return config

View File

@ -57,6 +57,42 @@ def _build_mcp_servers() -> dict[str, dict[str, Any]]:
return build_servers_config(ExtensionsConfig.from_file())
def _build_acp_mcp_servers() -> list[dict[str, Any]]:
"""Build ACP ``mcpServers`` payload for ``new_session``.
The ACP client expects a list of server objects, while DeerFlow's MCP helper
returns a name -> config mapping for the LangChain MCP adapter. This helper
converts the enabled servers into the ACP wire format.
"""
from deerflow.config.extensions_config import ExtensionsConfig
extensions_config = ExtensionsConfig.from_file()
enabled_servers = extensions_config.get_enabled_mcp_servers()
mcp_servers: list[dict[str, Any]] = []
for name, server_config in enabled_servers.items():
transport_type = server_config.type or "stdio"
payload: dict[str, Any] = {"name": name, "type": transport_type}
if transport_type == "stdio":
if not server_config.command:
raise ValueError(f"MCP server '{name}' with stdio transport requires 'command' field")
payload["command"] = server_config.command
payload["args"] = server_config.args
payload["env"] = [{"name": key, "value": value} for key, value in server_config.env.items()]
elif transport_type in ("http", "sse"):
if not server_config.url:
raise ValueError(f"MCP server '{name}' with {transport_type} transport requires 'url' field")
payload["url"] = server_config.url
payload["headers"] = [{"name": key, "value": value} for key, value in server_config.headers.items()]
else:
raise ValueError(f"MCP server '{name}' has unsupported transport type: {transport_type}")
mcp_servers.append(payload)
return mcp_servers
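# Illustrative input -> output for the conversion above (server name, command
# and values are made up):
#   extensions config : {"files": {"type": "stdio", "command": "mcp-files",
#                                  "args": ["--root", "/data"], "env": {"TOKEN": "t"}}}
#   ACP wire payload  : [{"name": "files", "type": "stdio", "command": "mcp-files",
#                         "args": ["--root", "/data"],
#                         "env": [{"name": "TOKEN", "value": "t"}]}]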
def _build_permission_response(options: list[Any], *, auto_approve: bool) -> Any:
"""Build an ACP permission response.
@ -173,7 +209,15 @@ def build_invoke_acp_agent_tool(agents: dict) -> BaseTool:
cmd = agent_config.command
args = agent_config.args or []
physical_cwd = _get_work_dir(thread_id)
mcp_servers = _build_mcp_servers()
try:
mcp_servers = _build_acp_mcp_servers()
except ValueError as exc:
logger.warning(
"Invalid MCP server configuration for ACP agent '%s'; continuing without MCP servers: %s",
agent,
exc,
)
mcp_servers = []
agent_env: dict[str, str] | None = None
if agent_config.env:
agent_env = {k: (os.environ.get(v[1:], "") if v.startswith("$") else v) for k, v in agent_config.env.items()}

View File

@ -1,10 +1,22 @@
"""File conversion utilities.
Converts document files (PDF, PPT, Excel, Word) to Markdown using markitdown.
Converts document files (PDF, PPT, Excel, Word) to Markdown.
PDF conversion strategy (auto mode):
1. Try pymupdf4llm if installed: better heading detection, faster on most files.
2. If output is suspiciously short (< _MIN_CHARS_PER_PAGE chars/page, or < 200 chars
total when page count is unavailable), treat as image-based and fall back to MarkItDown.
3. If pymupdf4llm is not installed, use MarkItDown directly (existing behaviour).
Large files (> _ASYNC_THRESHOLD_BYTES) are converted in a thread pool via
asyncio.to_thread() to avoid blocking the event loop (fixes #1569).
No FastAPI or HTTP dependencies; pure utility functions.
"""
import asyncio
import logging
import re
from pathlib import Path
logger = logging.getLogger(__name__)
@ -20,28 +32,278 @@ CONVERTIBLE_EXTENSIONS = {
".docx",
}
# Files larger than this threshold are converted in a background thread.
# Small files complete in < 1s synchronously; spawning a thread adds unnecessary
# scheduling overhead for them.
_ASYNC_THRESHOLD_BYTES = 1 * 1024 * 1024 # 1 MB
# If pymupdf4llm produces fewer characters *per page* than this threshold,
# the PDF is likely image-based or encrypted — fall back to MarkItDown.
# Rationale: normal text PDFs yield 200-2000 chars/page; image-based PDFs
# yield close to 0. 50 chars/page gives a wide safety margin.
# Falls back to absolute 200-char check when page count is unavailable.
_MIN_CHARS_PER_PAGE = 50
def _pymupdf_output_too_sparse(text: str, file_path: Path) -> bool:
"""Return True if pymupdf4llm output is suspiciously short (image-based PDF).
Uses chars-per-page rather than an absolute threshold so that both short
documents (few pages, few chars) and long documents (many pages, many chars)
are handled correctly.
"""
chars = len(text.strip())
doc = None
pages: int | None = None
try:
import pymupdf
doc = pymupdf.open(str(file_path))
pages = len(doc)
except Exception:
pass
finally:
if doc is not None:
try:
doc.close()
except Exception:
pass
if pages is not None and pages > 0:
return (chars / pages) < _MIN_CHARS_PER_PAGE
# Fallback: absolute threshold when page count is unavailable
return chars < 200
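# Worked example (illustrative numbers): a 10-page scanned PDF from which
# pymupdf4llm extracts 120 chars gives 120 / 10 = 12 chars/page < 50, so auto
# mode falls back to MarkItDown; a 2-page memo with 800 chars gives 400
# chars/page and the pymupdf4llm output is kept.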
def _convert_pdf_with_pymupdf4llm(file_path: Path) -> str | None:
"""Attempt PDF conversion with pymupdf4llm.
Returns the markdown text, or None if pymupdf4llm is not installed or
if conversion fails (e.g. encrypted/corrupt PDF).
"""
try:
import pymupdf4llm
except ImportError:
return None
try:
return pymupdf4llm.to_markdown(str(file_path))
except Exception:
logger.exception("pymupdf4llm failed to convert %s; falling back to MarkItDown", file_path.name)
return None
def _convert_with_markitdown(file_path: Path) -> str:
"""Convert any supported file to markdown text using MarkItDown."""
from markitdown import MarkItDown
md = MarkItDown()
return md.convert(str(file_path)).text_content
def _do_convert(file_path: Path, pdf_converter: str) -> str:
"""Synchronous conversion — called directly or via asyncio.to_thread.
Args:
file_path: Path to the file.
pdf_converter: "auto" | "pymupdf4llm" | "markitdown"
"""
is_pdf = file_path.suffix.lower() == ".pdf"
if is_pdf and pdf_converter != "markitdown":
# Try pymupdf4llm first (auto or explicit)
pymupdf_text = _convert_pdf_with_pymupdf4llm(file_path)
if pymupdf_text is not None:
# pymupdf4llm is installed
if pdf_converter == "pymupdf4llm":
# Explicit — use as-is regardless of output length
return pymupdf_text
# auto mode: fall back if output looks like a failed parse.
# Use chars-per-page to distinguish image-based PDFs (near 0) from
# legitimately short documents.
if not _pymupdf_output_too_sparse(pymupdf_text, file_path):
return pymupdf_text
logger.warning(
"pymupdf4llm produced only %d chars for %s (likely image-based PDF); falling back to MarkItDown",
len(pymupdf_text.strip()),
file_path.name,
)
# pymupdf4llm not installed or fallback triggered → use MarkItDown
return _convert_with_markitdown(file_path)
async def convert_file_to_markdown(file_path: Path) -> Path | None:
"""Convert a file to markdown using markitdown.
"""Convert a supported document file to Markdown.
PDF files are handled with a two-converter strategy (see module docstring).
Large files (> 1 MB) are offloaded to a thread pool to avoid blocking the
event loop.
Args:
file_path: Path to the file to convert.
Returns:
Path to the markdown file if conversion was successful, None otherwise.
Path to the generated .md file, or None if conversion failed.
"""
try:
from markitdown import MarkItDown
pdf_converter = _get_pdf_converter()
file_size = file_path.stat().st_size
md = MarkItDown()
result = md.convert(str(file_path))
if file_size > _ASYNC_THRESHOLD_BYTES:
text = await asyncio.to_thread(_do_convert, file_path, pdf_converter)
else:
text = _do_convert(file_path, pdf_converter)
# Save as .md file with same name
md_path = file_path.with_suffix(".md")
md_path.write_text(result.text_content, encoding="utf-8")
md_path.write_text(text, encoding="utf-8")
logger.info(f"Converted {file_path.name} to markdown: {md_path.name}")
logger.info("Converted %s to markdown: %s (%d chars)", file_path.name, md_path.name, len(text))
return md_path
except Exception as e:
logger.error(f"Failed to convert {file_path.name} to markdown: {e}")
logger.error("Failed to convert %s to markdown: %s", file_path.name, e)
return None
# Regex for bold-only lines that look like section headings.
# Targets SEC filing structural headings that pymupdf4llm renders as **bold**
# rather than # Markdown headings (because they use the same font size as body
# text, distinguished only by bold+caps formatting).
#
# Pattern requires ALL of:
# 1. Entire line is a single **...** block (no surrounding prose)
# 2. Starts with a recognised structural keyword:
#    - ITEM / PART / SECTION (with optional number/letter after)
#    - SCHEDULE, EXHIBIT, APPENDIX, ANNEX, CHAPTER
# All-caps addresses, boilerplate ("CURRENT REPORT", "SIGNATURES",
# "WASHINGTON, DC 20549") do NOT start with these keywords and are excluded.
#
# Chinese headings (第三节...) are already captured as standard # headings
# by pymupdf4llm, so they don't need this pattern.
_BOLD_HEADING_RE = re.compile(r"^\*\*((ITEM|PART|SECTION|SCHEDULE|EXHIBIT|APPENDIX|ANNEX|CHAPTER)\b[A-Z0-9 .,\-]*)\*\*\s*$")
# Regex for split-bold headings produced by pymupdf4llm when a heading spans
# multiple text spans in the PDF (e.g. section number and title are separate spans).
# Matches lines like: **1** **Introduction** or **3.2** **Multi-Head Attention**
# Requirements:
# 1. Entire line consists only of **...** blocks separated by whitespace (no prose)
# 2. First block is a section number (digits and dots, e.g. "1", "3.2", "A.1")
# 3. Second block must not be purely numeric/punctuation — excludes financial table
# headers like **2023** **2022** **2021** while allowing non-ASCII titles such as
# **1** **概述** or accented words (negative lookahead instead of [A-Za-z])
# 4. At most two additional blocks (four total) with [^*]+ (no * inside) to keep
# the regex linear and avoid ReDoS on attacker-controlled content
_SPLIT_BOLD_HEADING_RE = re.compile(r"^\*\*[\dA-Z][\d\.]*\*\*\s+\*\*(?!\d[\d\s.,\-–—/:()%]*\*\*)[^*]+\*\*(?:\s+\*\*[^*]+\*\*){0,2}\s*$")
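# Examples for the split-bold pattern (illustrative):
#   "**1** **Introduction**"            -> heading ("1 Introduction")
#   "**3.2** **Multi-Head Attention**"  -> heading ("3.2 Multi-Head Attention")
#   "**2023** **2022** **2021**"        -> rejected (second block is numeric)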
# Maximum number of outline entries injected into the agent context.
# Keeps prompt size bounded even for very long documents.
MAX_OUTLINE_ENTRIES = 50
_ALLOWED_PDF_CONVERTERS = {"auto", "pymupdf4llm", "markitdown"}
def _clean_bold_title(raw: str) -> str:
"""Normalise a title string that may contain pymupdf4llm bold artefacts.
pymupdf4llm sometimes emits adjacent bold spans as ``**A** **B**`` instead
of a single ``**A B**`` block. This helper merges those fragments and then
strips the outermost ``**...**`` wrapper so the caller gets plain text.
Examples::
"**Overview**" "Overview"
"**UNITED STATES** **SECURITIES**" "UNITED STATES SECURITIES"
"plain text" "plain text" (unchanged)
"""
# Merge adjacent bold spans: "** **" → " "
merged = re.sub(r"\*\*\s*\*\*", " ", raw).strip()
# Strip outermost **...** if the whole string is wrapped
if m := re.fullmatch(r"\*\*(.+?)\*\*", merged, re.DOTALL):
return m.group(1).strip()
return merged
def extract_outline(md_path: Path) -> list[dict]:
"""Extract document outline (headings) from a Markdown file.
Recognises three heading styles produced by pymupdf4llm:
1. Standard Markdown headings: lines starting with one or more '#'.
Inline ``**...**`` wrappers and adjacent bold spans (``** **``) are
cleaned so the title is plain text.
2. Bold-only structural headings: ``**ITEM 1. BUSINESS**``, ``**PART II**``,
etc. SEC filings use bold+caps for section headings with the same font
size as body text, so pymupdf4llm cannot promote them to # headings.
3. Split-bold headings: ``**1** **Introduction**``, ``**3.2** **Attention**``.
pymupdf4llm emits these when the section number and title text are
separate spans in the underlying PDF (common in academic papers).
Args:
md_path: Path to the .md file.
Returns:
List of dicts with keys: title (str), line (int, 1-based).
When the outline is truncated at MAX_OUTLINE_ENTRIES, a sentinel entry
``{"truncated": True}`` is appended as the last element so callers can
render a "showing first N headings" hint without re-scanning the file.
Returns an empty list if the file cannot be read or has no headings.
"""
outline: list[dict] = []
try:
with md_path.open(encoding="utf-8") as f:
for lineno, line in enumerate(f, 1):
stripped = line.strip()
if not stripped:
continue
# Style 1: standard Markdown heading
if stripped.startswith("#"):
title = _clean_bold_title(stripped.lstrip("#").strip())
if title:
outline.append({"title": title, "line": lineno})
# Style 2: single bold block with SEC structural keyword
elif m := _BOLD_HEADING_RE.match(stripped):
title = m.group(1).strip()
if title:
outline.append({"title": title, "line": lineno})
# Style 3: split-bold heading — **<num>** **<title>**
# Regex already enforces max 4 blocks and non-numeric second block.
elif _SPLIT_BOLD_HEADING_RE.match(stripped):
title = " ".join(re.findall(r"\*\*([^*]+)\*\*", stripped))
if title:
outline.append({"title": title, "line": lineno})
if len(outline) >= MAX_OUTLINE_ENTRIES:
outline.append({"truncated": True})
break
except Exception:
return []
return outline
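# Hedged usage sketch for extract_outline(); the sample file exercises all
# three heading styles. The import path for this module is not shown in the
# diff, so the call is written as if extract_outline is already in scope.
import tempfile
from pathlib import Path

sample = Path(tempfile.mkdtemp()) / "outline_demo.md"
sample.write_text(
    "# **Overview**\n"
    "**ITEM 1. BUSINESS**\n"
    "**3.2** **Multi-Head Attention**\n",
    encoding="utf-8",
)
print(extract_outline(sample))
# [{'title': 'Overview', 'line': 1},
#  {'title': 'ITEM 1. BUSINESS', 'line': 2},
#  {'title': '3.2 Multi-Head Attention', 'line': 3}]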
def _get_pdf_converter() -> str:
"""Read pdf_converter setting from app config, defaulting to 'auto'.
Normalizes the value to lowercase and validates it against the allowed set
so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
fall through to unexpected behaviour.
"""
try:
from deerflow.config.app_config import get_app_config
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
if uploads_cfg is not None:
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
if raw not in _ALLOWED_PDF_CONVERTERS:
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
return "auto"
return raw
except Exception:
pass
return "auto"

View File

@ -9,16 +9,17 @@ dependencies = [
"dotenv>=0.9.9",
"httpx>=0.28.0",
"kubernetes>=30.0.0",
"langchain>=1.2.3",
"langchain>=1.2.3,<1.2.10",
"langchain-anthropic>=1.3.4",
"langchain-deepseek>=1.0.1",
"langchain-mcp-adapters>=0.1.0",
"langchain-openai>=1.1.7",
"langfuse>=3.4.1",
"langgraph>=1.0.6,<1.0.10",
"langgraph-prebuilt>=1.0.6,<1.0.9",
"langgraph-api>=0.7.0,<0.8.0",
"langgraph-cli>=0.4.14",
"langgraph-runtime-inmem>=0.22.1",
"langgraph-runtime-inmem>=0.22.1,<0.27.0",
"markdownify>=1.2.2",
"markitdown[all,xlsx]>=0.0.1a2",
"pydantic>=2.12.5",
@ -34,6 +35,9 @@ dependencies = [
"langgraph-sdk>=0.1.51",
]
[project.optional-dependencies]
pymupdf = ["pymupdf4llm>=0.0.17"]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

View File

@ -16,6 +16,10 @@ dependencies = [
"python-telegram-bot>=21.0",
"langgraph-sdk>=0.1.51",
"markdown-to-mrkdwn>=0.3.1",
"wecom-aibot-python-sdk>=0.1.6",
"bcrypt>=4.0.0",
"pyjwt>=2.9.0",
"email-validator>=2.0.0",
]
[dependency-groups]

backend/tests/test_auth.py Normal file
View File

@ -0,0 +1,506 @@
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
from app.gateway.auth.models import User
from app.gateway.authz import (
AuthContext,
Permissions,
get_auth_context,
require_auth,
require_permission,
)
# ── Password Hashing ────────────────────────────────────────────────────────
def test_hash_password_and_verify():
"""Hashing and verification round-trip."""
password = "s3cr3tP@ssw0rd!"
hashed = hash_password(password)
assert hashed != password
assert verify_password(password, hashed) is True
assert verify_password("wrongpassword", hashed) is False
def test_hash_password_different_each_time():
"""bcrypt generates unique salts, so same password has different hashes."""
password = "testpassword"
h1 = hash_password(password)
h2 = hash_password(password)
assert h1 != h2 # Different salts
# But both verify correctly
assert verify_password(password, h1) is True
assert verify_password(password, h2) is True
def test_verify_password_rejects_empty():
"""Empty password should not verify."""
hashed = hash_password("nonempty")
assert verify_password("", hashed) is False
# ── JWT ─────────────────────────────────────────────────────────────────────
def test_create_and_decode_token():
"""JWT creation and decoding round-trip."""
user_id = str(uuid4())
# Set a valid JWT secret for this test
import os
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(user_id)
assert isinstance(token, str)
payload = decode_token(token)
assert payload is not None
assert payload.sub == user_id
def test_decode_token_expired():
"""Expired token returns TokenError.EXPIRED."""
from app.gateway.auth.errors import TokenError
user_id = str(uuid4())
# Create token that expires immediately
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
payload = decode_token(token)
assert payload == TokenError.EXPIRED
def test_decode_token_invalid():
"""Invalid token returns TokenError."""
from app.gateway.auth.errors import TokenError
assert isinstance(decode_token("not.a.valid.token"), TokenError)
assert isinstance(decode_token(""), TokenError)
assert isinstance(decode_token("completely-wrong"), TokenError)
def test_create_token_custom_expiry():
"""Custom expiry is respected."""
user_id = str(uuid4())
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
payload = decode_token(token)
assert payload is not None
assert payload.sub == user_id
# ── AuthContext ────────────────────────────────────────────────────────────
def test_auth_context_unauthenticated():
"""AuthContext with no user."""
ctx = AuthContext(user=None, permissions=[])
assert ctx.is_authenticated is False
assert ctx.has_permission("threads", "read") is False
def test_auth_context_authenticated_no_perms():
"""AuthContext with user but no permissions."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[])
assert ctx.is_authenticated is True
assert ctx.has_permission("threads", "read") is False
def test_auth_context_has_permission():
"""AuthContext permission checking."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
ctx = AuthContext(user=user, permissions=perms)
assert ctx.has_permission("threads", "read") is True
assert ctx.has_permission("threads", "write") is True
assert ctx.has_permission("threads", "delete") is False
assert ctx.has_permission("runs", "read") is False
def test_auth_context_require_user_raises():
"""require_user raises 401 when not authenticated."""
ctx = AuthContext(user=None, permissions=[])
with pytest.raises(HTTPException) as exc_info:
ctx.require_user()
assert exc_info.value.status_code == 401
def test_auth_context_require_user_returns_user():
"""require_user returns user when authenticated."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[])
returned = ctx.require_user()
assert returned == user
# ── get_auth_context helper ─────────────────────────────────────────────────
def test_get_auth_context_not_set():
"""get_auth_context returns None when auth not set on request."""
mock_request = MagicMock()
# Make getattr return None (simulating attribute not set)
mock_request.state = MagicMock()
del mock_request.state.auth
assert get_auth_context(mock_request) is None
def test_get_auth_context_set():
"""get_auth_context returns the AuthContext from request."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
mock_request = MagicMock()
mock_request.state.auth = ctx
assert get_auth_context(mock_request) == ctx
# ── require_auth decorator ──────────────────────────────────────────────────
def test_require_auth_sets_auth_context():
"""require_auth sets auth context on request from cookie."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_auth
async def endpoint(request: Request):
ctx = get_auth_context(request)
return {"authenticated": ctx.is_authenticated}
with TestClient(app) as client:
# No cookie → anonymous
response = client.get("/test")
assert response.status_code == 200
assert response.json()["authenticated"] is False
def test_require_auth_requires_request_param():
"""require_auth raises ValueError if request parameter is missing."""
import asyncio
@require_auth
async def bad_endpoint(): # Missing `request` parameter
pass
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
asyncio.run(bad_endpoint())
# ── require_permission decorator ─────────────────────────────────────────────
def test_require_permission_requires_auth():
"""require_permission raises 401 when not authenticated."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_permission("threads", "read")
async def endpoint(request: Request):
return {"ok": True}
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 401
assert "Authentication required" in response.json()["detail"]
def test_require_permission_denies_wrong_permission():
"""User without required permission gets 403."""
from fastapi import Request
app = FastAPI()
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
@app.get("/test")
@require_permission("threads", "delete")
async def endpoint(request: Request):
return {"ok": True}
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 403
assert "Permission denied" in response.json()["detail"]
# ── Weak JWT secret warning ──────────────────────────────────────────────────
# ── User Model Fields ──────────────────────────────────────────────────────
def test_user_model_has_needs_setup_default_false():
"""New users default to needs_setup=False."""
user = User(email="test@example.com", password_hash="hash")
assert user.needs_setup is False
def test_user_model_has_token_version_default_zero():
"""New users default to token_version=0."""
user = User(email="test@example.com", password_hash="hash")
assert user.token_version == 0
def test_user_model_needs_setup_true():
"""Auto-created admin has needs_setup=True."""
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
assert user.needs_setup is True
def test_sqlite_round_trip_new_fields():
"""needs_setup and token_version survive create → read round-trip."""
import asyncio
import os
import tempfile
from pathlib import Path
from app.gateway.auth.repositories import sqlite as sqlite_mod
with tempfile.TemporaryDirectory() as tmpdir:
db_path = os.path.join(tmpdir, "test_users.db")
old_path = sqlite_mod._resolved_db_path
old_init = sqlite_mod._table_initialized
sqlite_mod._resolved_db_path = Path(db_path)
sqlite_mod._table_initialized = False
try:
repo = sqlite_mod.SQLiteUserRepository()
user = User(
email="setup@test.com",
password_hash="fakehash",
system_role="admin",
needs_setup=True,
token_version=3,
)
created = asyncio.run(repo.create_user(user))
assert created.needs_setup is True
assert created.token_version == 3
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
assert fetched is not None
assert fetched.needs_setup is True
assert fetched.token_version == 3
fetched.needs_setup = False
fetched.token_version = 4
asyncio.run(repo.update_user(fetched))
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
assert refetched.needs_setup is False
assert refetched.token_version == 4
finally:
sqlite_mod._resolved_db_path = old_path
sqlite_mod._table_initialized = old_init
# ── Token Versioning ───────────────────────────────────────────────────────
def test_jwt_encodes_ver():
"""JWT payload includes ver field."""
import os
from app.gateway.auth.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()), token_version=3)
payload = decode_token(token)
assert not isinstance(payload, TokenError)
assert payload.ver == 3
def test_jwt_default_ver_zero():
"""JWT ver defaults to 0."""
import os
from app.gateway.auth.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()))
payload = decode_token(token)
assert not isinstance(payload, TokenError)
assert payload.ver == 0
def test_token_version_mismatch_rejects():
"""Token with stale ver is rejected by get_current_user_from_request."""
import asyncio
import os
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
user_id = str(uuid4())
token = create_access_token(user_id, token_version=0)
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
mock_request = MagicMock()
mock_request.cookies = {"access_token": token}
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
mock_provider = MagicMock()
mock_provider.get_user = AsyncMock(return_value=mock_user)
mock_provider_fn.return_value = mock_provider
from app.gateway.deps import get_current_user_from_request
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
assert "revoked" in str(exc_info.value.detail).lower()
# ── change-password extension ──────────────────────────────────────────────
def test_change_password_request_accepts_new_email():
"""ChangePasswordRequest model accepts optional new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
req = ChangePasswordRequest(
current_password="old",
new_password="newpassword",
new_email="new@example.com",
)
assert req.new_email == "new@example.com"
def test_change_password_request_new_email_optional():
"""ChangePasswordRequest model works without new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
assert req.new_email is None
def test_login_response_includes_needs_setup():
"""LoginResponse includes needs_setup field."""
from app.gateway.routers.auth import LoginResponse
resp = LoginResponse(expires_in=3600, needs_setup=True)
assert resp.needs_setup is True
resp2 = LoginResponse(expires_in=3600)
assert resp2.needs_setup is False
# ── Rate Limiting ──────────────────────────────────────────────────────────
def test_rate_limiter_allows_under_limit():
"""Requests under the limit are allowed."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
_login_attempts.clear()
_check_rate_limit("192.168.1.1") # Should not raise
def test_rate_limiter_blocks_after_max_failures():
"""IP is blocked after 5 consecutive failures."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
_login_attempts.clear()
ip = "10.0.0.1"
for _ in range(5):
_record_login_failure(ip)
with pytest.raises(HTTPException) as exc_info:
_check_rate_limit(ip)
assert exc_info.value.status_code == 429
def test_rate_limiter_resets_on_success():
"""Successful login clears the failure counter."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
_login_attempts.clear()
ip = "10.0.0.2"
for _ in range(4):
_record_login_failure(ip)
_record_login_success(ip)
_check_rate_limit(ip) # Should not raise
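# Minimal sketch of the limiter behaviour pinned above (hypothetical,
# simplified from the shipped module): five consecutive failures inside a
# five-minute window lock an IP out; a successful login clears its entry,
# which also keeps the attempts dict from growing without bound.
import time as _time

_sketch_attempts: dict[str, list[float]] = {}
_SKETCH_MAX_FAILURES, _SKETCH_WINDOW_SECONDS = 5, 300

def _sketch_check_rate_limit(ip: str) -> None:
    now = _time.monotonic()
    recent = [t for t in _sketch_attempts.get(ip, []) if now - t < _SKETCH_WINDOW_SECONDS]
    if len(recent) >= _SKETCH_MAX_FAILURES:
        raise HTTPException(status_code=429, detail="Too many login attempts")

def _sketch_record_login_failure(ip: str) -> None:
    _sketch_attempts.setdefault(ip, []).append(_time.monotonic())

def _sketch_record_login_success(ip: str) -> None:
    _sketch_attempts.pop(ip, None)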
# ── Client IP extraction ─────────────────────────────────────────────────
def test_get_client_ip_direct_connection():
"""Without nginx (no X-Real-IP), falls back to request.client.host."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "203.0.113.42"
req.headers = {}
assert _get_client_ip(req) == "203.0.113.42"
def test_get_client_ip_uses_x_real_ip():
"""X-Real-IP (set by nginx) is used when present."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0]
req.headers = {"x-real-ip": "203.0.113.42"}
assert _get_client_ip(req) == "203.0.113.42"
def test_get_client_ip_xff_ignored():
"""X-Forwarded-For is never used; only X-Real-IP matters."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "10.0.0.1"
req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
assert _get_client_ip(req) == "198.51.100.5"
def test_get_client_ip_no_real_ip_fallback():
"""No X-Real-IP → falls back to client.host (direct connection)."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "127.0.0.1"
req.headers = {}
assert _get_client_ip(req) == "127.0.0.1"
def test_get_client_ip_x_real_ip_always_preferred():
"""X-Real-IP is always preferred over client.host regardless of IP."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "203.0.113.99"
req.headers = {"x-real-ip": "198.51.100.7"}
assert _get_client_ip(req) == "198.51.100.7"
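# Sketch of the extraction rule encoded above (assumed shape): only
# X-Real-IP — set by nginx from the TCP peer, so not client-spoofable — is
# trusted; X-Forwarded-For is ignored entirely, and a direct connection
# falls back to request.client.host.
def _sketch_get_client_ip(request) -> str:
    return request.headers.get("x-real-ip") or request.client.host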
# ── Weak JWT secret warning ──────────────────────────────────────────────────
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as config_module
config_module._auth_config = None
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
with caplog.at_level(logging.WARNING):
config = config_module.get_auth_config()
assert config.jwt_secret # non-empty ephemeral secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
# Cleanup
config_module._auth_config = None
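# Sketch of the fallback the test above exercises (assumed shape): with
# AUTH_JWT_SECRET unset, a random per-process secret is generated and a
# warning is logged — existing sessions then stop verifying after a restart.
import logging as _logging
import os as _os
import secrets as _secrets

def _sketch_load_jwt_secret() -> str:
    secret = _os.environ.get("AUTH_JWT_SECRET")
    if not secret:
        secret = _secrets.token_urlsafe(32)
        _logging.getLogger(__name__).warning(
            "AUTH_JWT_SECRET is not set; using an ephemeral secret"
        )
    return secret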

View File

@ -0,0 +1,54 @@
"""Tests for AuthConfig typed configuration."""
import os
from unittest.mock import patch
import pytest
from app.gateway.auth.config import AuthConfig
def test_auth_config_defaults():
config = AuthConfig(jwt_secret="test-secret-key-123")
assert config.token_expiry_days == 7
def test_auth_config_token_expiry_range():
AuthConfig(jwt_secret="s", token_expiry_days=1)
AuthConfig(jwt_secret="s", token_expiry_days=30)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=0)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=31)
def test_auth_config_from_env():
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
with patch.dict(os.environ, env, clear=False):
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
config = cfg.get_auth_config()
assert config.jwt_secret == "test-jwt-secret-from-env"
finally:
cfg._auth_config = old
def test_auth_config_missing_secret_generates_ephemeral(caplog):
import logging
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
finally:
cfg._auth_config = old

View File

@ -0,0 +1,75 @@
"""Tests for auth error types and typed decode_token."""
from datetime import UTC, datetime, timedelta
import jwt as pyjwt
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import create_access_token, decode_token
def test_auth_error_code_values():
assert AuthErrorCode.INVALID_CREDENTIALS == "invalid_credentials"
assert AuthErrorCode.TOKEN_EXPIRED == "token_expired"
assert AuthErrorCode.NOT_AUTHENTICATED == "not_authenticated"
def test_token_error_values():
assert TokenError.EXPIRED == "expired"
assert TokenError.INVALID_SIGNATURE == "invalid_signature"
assert TokenError.MALFORMED == "malformed"
def test_auth_error_response_serialization():
err = AuthErrorResponse(
code=AuthErrorCode.TOKEN_EXPIRED,
message="Token has expired",
)
d = err.model_dump()
assert d == {"code": "token_expired", "message": "Token has expired"}
def test_auth_error_response_from_dict():
d = {"code": "invalid_credentials", "message": "Wrong password"}
err = AuthErrorResponse(**d)
assert err.code == AuthErrorCode.INVALID_CREDENTIALS
# ── decode_token typed failure tests ──────────────────────────────
_TEST_SECRET = "test-secret-for-jwt-decode-token-tests"
def _setup_config():
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
def test_decode_token_returns_token_error_on_expired():
_setup_config()
expired_payload = {"sub": "user-1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired_payload, _TEST_SECRET, algorithm="HS256")
result = decode_token(token)
assert result == TokenError.EXPIRED
def test_decode_token_returns_token_error_on_bad_signature():
_setup_config()
payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
result = decode_token(token)
assert result == TokenError.INVALID_SIGNATURE
def test_decode_token_returns_token_error_on_malformed():
_setup_config()
result = decode_token("not-a-jwt")
assert result == TokenError.MALFORMED
def test_decode_token_returns_payload_on_valid():
_setup_config()
token = create_access_token("user-123")
result = decode_token(token)
assert not isinstance(result, TokenError)
assert result.sub == "user-123"
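# Sketch of the typed-failure pattern these tests fix in place (simplified
# from the real module — claims are returned raw here, not as a payload
# model): decode_token maps PyJWT exceptions onto TokenError members instead
# of raising, so every caller must branch on the returned union.
def _sketch_decode_token(token: str, secret: str):
    try:
        return pyjwt.decode(token, secret, algorithms=["HS256"])
    except pyjwt.ExpiredSignatureError:
        return TokenError.EXPIRED
    except pyjwt.InvalidSignatureError:
        return TokenError.INVALID_SIGNATURE
    except pyjwt.InvalidTokenError:
        return TokenError.MALFORMED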

View File

@ -0,0 +1,216 @@
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
import pytest
from starlette.testclient import TestClient
from app.gateway.auth_middleware import AuthMiddleware, _is_public
# ── _is_public unit tests ─────────────────────────────────────────────────
@pytest.mark.parametrize(
"path",
[
"/health",
"/health/",
"/docs",
"/docs/",
"/redoc",
"/openapi.json",
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
],
)
def test_public_paths(path: str):
assert _is_public(path) is True
@pytest.mark.parametrize(
"path",
[
"/api/models",
"/api/mcp/config",
"/api/memory",
"/api/skills",
"/api/threads/123",
"/api/threads/123/uploads",
"/api/agents",
"/api/channels",
"/api/runs/stream",
"/api/threads/123/runs",
"/api/v1/auth/me",
"/api/v1/auth/change-password",
],
)
def test_protected_paths(path: str):
assert _is_public(path) is False
# ── Trailing slash / normalization edge cases ─────────────────────────────
@pytest.mark.parametrize(
"path",
[
"/api/v1/auth/login/local/",
"/api/v1/auth/register/",
"/api/v1/auth/logout/",
"/api/v1/auth/setup-status/",
],
)
def test_public_auth_paths_with_trailing_slash(path: str):
assert _is_public(path) is True
@pytest.mark.parametrize(
"path",
[
"/api/models/",
"/api/v1/auth/me/",
"/api/v1/auth/change-password/",
],
)
def test_protected_paths_with_trailing_slash(path: str):
assert _is_public(path) is False
def test_unknown_api_path_is_protected():
"""Fail-closed: any new /api/* path is protected by default."""
assert _is_public("/api/new-feature") is False
assert _is_public("/api/v2/something") is False
assert _is_public("/api/v1/auth/new-endpoint") is False
# ── Middleware integration tests ──────────────────────────────────────────
def _make_app():
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
from fastapi import FastAPI
app = FastAPI()
app.add_middleware(AuthMiddleware)
@app.get("/health")
async def health():
return {"status": "ok"}
@app.get("/api/v1/auth/me")
async def auth_me():
return {"id": "1", "email": "test@test.com"}
@app.get("/api/v1/auth/setup-status")
async def setup_status():
return {"needs_setup": False}
@app.get("/api/models")
async def models_get():
return {"models": []}
@app.put("/api/mcp/config")
async def mcp_put():
return {"ok": True}
@app.delete("/api/threads/abc")
async def thread_delete():
return {"ok": True}
@app.patch("/api/threads/abc")
async def thread_patch():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def stream():
return {"ok": True}
@app.get("/api/future-endpoint")
async def future():
return {"ok": True}
return app
@pytest.fixture
def client():
return TestClient(_make_app())
def test_public_path_no_cookie(client):
res = client.get("/health")
assert res.status_code == 200
def test_public_auth_path_no_cookie(client):
"""Public auth endpoints (login/register) pass without cookie."""
res = client.get("/api/v1/auth/setup-status")
assert res.status_code == 200
def test_protected_auth_path_no_cookie(client):
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
res = client.get("/api/v1/auth/me")
assert res.status_code == 401
def test_protected_path_no_cookie_returns_401(client):
res = client.get("/api/models")
assert res.status_code == 401
body = res.json()
assert body["detail"]["code"] == "not_authenticated"
def test_protected_path_with_cookie_passes(client):
res = client.get("/api/models", cookies={"access_token": "some-token"})
assert res.status_code == 200
def test_protected_post_no_cookie_returns_401(client):
res = client.post("/api/threads/abc/runs/stream")
assert res.status_code == 401
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
def test_protected_put_no_cookie(client):
res = client.put("/api/mcp/config")
assert res.status_code == 401
def test_protected_delete_no_cookie(client):
res = client.delete("/api/threads/abc")
assert res.status_code == 401
def test_protected_patch_no_cookie(client):
res = client.patch("/api/threads/abc")
assert res.status_code == 401
def test_put_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.put("/api/mcp/config")
assert res.status_code == 200
def test_delete_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.delete("/api/threads/abc")
assert res.status_code == 200
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
def test_unknown_endpoint_no_cookie_returns_401(client):
"""Any new /api/* endpoint is blocked by default without cookie."""
res = client.get("/api/future-endpoint")
assert res.status_code == 401
def test_unknown_endpoint_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.get("/api/future-endpoint")
assert res.status_code == 200

View File

@ -0,0 +1,675 @@
"""Tests for auth type system hardening.
Covers structured error responses, typed decode_token callers,
CSRF middleware path matching, config-driven cookie security,
and unhappy paths / edge cases for all auth boundaries.
"""
import os
import secrets
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import jwt as pyjwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.csrf_middleware import (
CSRF_COOKIE_NAME,
CSRF_HEADER_NAME,
CSRFMiddleware,
is_auth_endpoint,
should_check_csrf,
)
# ── Setup ────────────────────────────────────────────────────────────
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
def _setup_config():
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
# ── CSRF Middleware Path Matching ────────────────────────────────────
class _FakeRequest:
"""Minimal request mock for CSRF path matching tests."""
def __init__(self, path: str, method: str = "POST"):
self.method = method
class _URL:
def __init__(self, p):
self.path = p
self.url = _URL(path)
self.cookies = {}
self.headers = {}
def test_csrf_exempts_login_local():
"""login/local (actual route) should be exempt from CSRF."""
req = _FakeRequest("/api/v1/auth/login/local")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_login_local_trailing_slash():
"""Trailing slash should also be exempt."""
req = _FakeRequest("/api/v1/auth/login/local/")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_logout():
req = _FakeRequest("/api/v1/auth/logout")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_register():
req = _FakeRequest("/api/v1/auth/register")
assert is_auth_endpoint(req) is True
def test_csrf_does_not_exempt_old_login_path():
"""Old /api/v1/auth/login (without /local) should NOT be exempt."""
req = _FakeRequest("/api/v1/auth/login")
assert is_auth_endpoint(req) is False
def test_csrf_does_not_exempt_me():
req = _FakeRequest("/api/v1/auth/me")
assert is_auth_endpoint(req) is False
def test_csrf_skips_get_requests():
req = _FakeRequest("/api/v1/auth/me", method="GET")
assert should_check_csrf(req) is False
def test_csrf_checks_post_to_protected():
req = _FakeRequest("/api/v1/some/endpoint", method="POST")
assert should_check_csrf(req) is True
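# Sketch of the two predicates exercised above (assumed shapes): CSRF is
# enforced only on state-changing methods, and the pre-login auth endpoints
# are exempt because a client holds no CSRF cookie before its first
# login/register round-trip.
_SKETCH_CSRF_EXEMPT = {"/api/v1/auth/login/local", "/api/v1/auth/register", "/api/v1/auth/logout"}
_SKETCH_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}

def _sketch_is_auth_endpoint(request) -> bool:
    return request.url.path.rstrip("/") in _SKETCH_CSRF_EXEMPT

def _sketch_should_check_csrf(request) -> bool:
    return request.method not in _SKETCH_SAFE_METHODS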
# ── Structured Error Response Format ────────────────────────────────
def test_auth_error_response_has_code_and_message():
"""All auth errors should have structured {code, message} format."""
err = AuthErrorResponse(
code=AuthErrorCode.INVALID_CREDENTIALS,
message="Wrong password",
)
d = err.model_dump()
assert "code" in d
assert "message" in d
assert d["code"] == "invalid_credentials"
def test_auth_error_response_all_codes_serializable():
"""Every AuthErrorCode should be serializable in AuthErrorResponse."""
for code in AuthErrorCode:
err = AuthErrorResponse(code=code, message=f"Test {code.value}")
d = err.model_dump()
assert d["code"] == code.value
# ── decode_token Caller Pattern ──────────────────────────────────────
def test_decode_token_expired_maps_to_token_expired_code():
"""TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
result = decode_token(token)
assert result == TokenError.EXPIRED
# Verify the mapping pattern used in route handlers
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_EXPIRED
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
"""TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
result = decode_token(token)
assert result == TokenError.INVALID_SIGNATURE
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_INVALID
def test_decode_token_malformed_maps_to_token_invalid_code():
"""TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
_setup_config()
result = decode_token("garbage")
assert result == TokenError.MALFORMED
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_INVALID
# ── Login Response Format ────────────────────────────────────────────
def test_login_response_model_has_no_access_token():
"""LoginResponse should NOT contain access_token field (RFC-001)."""
from app.gateway.routers.auth import LoginResponse
resp = LoginResponse(expires_in=604800)
d = resp.model_dump()
assert "access_token" not in d
assert "expires_in" in d
assert d["expires_in"] == 604800
def test_login_response_model_fields():
"""LoginResponse has expires_in and needs_setup."""
from app.gateway.routers.auth import LoginResponse
fields = set(LoginResponse.model_fields.keys())
assert fields == {"expires_in", "needs_setup"}
# ── AuthConfig in Route ──────────────────────────────────────────────
def test_auth_config_token_expiry_used_in_login_response():
"""LoginResponse.expires_in should come from config.token_expiry_days."""
from app.gateway.routers.auth import LoginResponse
expected_seconds = 14 * 24 * 3600
resp = LoginResponse(expires_in=expected_seconds)
assert resp.expires_in == expected_seconds
# ── UserResponse Type Preservation ───────────────────────────────────
def test_user_response_system_role_literal():
"""UserResponse.system_role should only accept 'admin' or 'user'."""
from app.gateway.auth.models import UserResponse
# Valid roles
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
assert resp.system_role == "admin"
resp = UserResponse(id="1", email="a@b.com", system_role="user")
assert resp.system_role == "user"
def test_user_response_rejects_invalid_role():
"""UserResponse should reject invalid system_role values."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="superadmin")
# ══════════════════════════════════════════════════════════════════════
# UNHAPPY PATHS / EDGE CASES
# ══════════════════════════════════════════════════════════════════════
# ── get_current_user structured 401 responses ────────────────────────
def test_get_current_user_no_cookie_returns_not_authenticated():
"""No cookie → 401 with code=not_authenticated."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
mock_request = type("MockRequest", (), {"cookies": {}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "not_authenticated"
def test_get_current_user_expired_token_returns_token_expired():
"""Expired token → 401 with code=token_expired."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_expired"
def test_get_current_user_invalid_token_returns_token_invalid():
"""Bad signature → 401 with code=token_invalid."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_invalid"
def test_get_current_user_malformed_token_returns_token_invalid():
"""Garbage token → 401 with code=token_invalid."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_invalid"
# ── decode_token edge cases ──────────────────────────────────────────
def test_decode_token_empty_string_returns_malformed():
_setup_config()
result = decode_token("")
assert result == TokenError.MALFORMED
def test_decode_token_whitespace_returns_malformed():
_setup_config()
result = decode_token(" ")
assert result == TokenError.MALFORMED
# ── AuthConfig validation edge cases ─────────────────────────────────
def test_auth_config_missing_jwt_secret_raises():
"""AuthConfig requires jwt_secret — no default allowed."""
with pytest.raises(ValidationError):
AuthConfig()
def test_auth_config_token_expiry_zero_raises():
"""token_expiry_days must be >= 1."""
with pytest.raises(ValidationError):
AuthConfig(jwt_secret="secret", token_expiry_days=0)
def test_auth_config_token_expiry_31_raises():
"""token_expiry_days must be <= 30."""
with pytest.raises(ValidationError):
AuthConfig(jwt_secret="secret", token_expiry_days=31)
def test_auth_config_token_expiry_boundary_1_ok():
config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
assert config.token_expiry_days == 1
def test_auth_config_token_expiry_boundary_30_ok():
config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
assert config.token_expiry_days == 30
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
finally:
cfg._auth_config = old
# ── CSRF middleware integration (unhappy paths) ──────────────────────
def _make_csrf_app():
"""Create a minimal FastAPI app with CSRFMiddleware for testing."""
from fastapi import HTTPException as _HTTPException
from fastapi.responses import JSONResponse as _JSONResponse
app = FastAPI()
@app.exception_handler(_HTTPException)
async def _http_exc_handler(request, exc):
return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/test/protected")
async def protected():
return {"ok": True}
@app.post("/api/v1/auth/login/local")
async def login():
return {"ok": True}
@app.get("/api/v1/test/read")
async def read_endpoint():
return {"ok": True}
return app
def test_csrf_middleware_blocks_post_without_token():
"""POST to protected endpoint without CSRF token → 403 with structured detail."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/test/protected")
assert resp.status_code == 403
assert "CSRF" in resp.json()["detail"]
assert "missing" in resp.json()["detail"].lower()
def test_csrf_middleware_blocks_post_with_mismatched_token():
"""POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
client = TestClient(_make_csrf_app())
client.cookies.set(CSRF_COOKIE_NAME, "token-a")
resp = client.post(
"/api/v1/test/protected",
headers={CSRF_HEADER_NAME: "token-b"},
)
assert resp.status_code == 403
assert "mismatch" in resp.json()["detail"].lower()
def test_csrf_middleware_allows_post_with_matching_token():
"""POST with matching CSRF cookie/header → 200."""
client = TestClient(_make_csrf_app())
token = secrets.token_urlsafe(64)
client.cookies.set(CSRF_COOKIE_NAME, token)
resp = client.post(
"/api/v1/test/protected",
headers={CSRF_HEADER_NAME: token},
)
assert resp.status_code == 200
def test_csrf_middleware_allows_get_without_token():
"""GET requests bypass CSRF check."""
client = TestClient(_make_csrf_app())
resp = client.get("/api/v1/test/read")
assert resp.status_code == 200
def test_csrf_middleware_exempts_login_local():
"""POST to login/local is exempt from CSRF (no token yet)."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/auth/login/local")
assert resp.status_code == 200
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
"""Auth endpoints should receive a CSRF cookie in response."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/auth/login/local")
assert CSRF_COOKIE_NAME in resp.cookies
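# Sketch of the double-submit comparison behind these middleware tests
# (assumed shape): the JS-readable csrf_token cookie must be echoed in the
# request header and compared in constant time; a missing pair and a
# mismatched pair both surface as the 403s asserted above.
import hmac as _hmac

def _sketch_csrf_tokens_match(cookie_token: str | None, header_token: str | None) -> bool:
    if not cookie_token or not header_token:
        return False  # the "missing" case
    return _hmac.compare_digest(cookie_token, header_token)  # False → "mismatch"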
# ── UserResponse edge cases ──────────────────────────────────────────
def test_user_response_missing_required_fields():
"""UserResponse with missing fields → ValidationError."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1") # missing email, system_role
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com") # missing system_role
def test_user_response_empty_string_role_rejected():
"""Empty string is not a valid role."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="")
# ══════════════════════════════════════════════════════════════════════
# HTTP-LEVEL API CONTRACT TESTS
# ══════════════════════════════════════════════════════════════════════
def _make_auth_app():
"""Create FastAPI app with auth routes for contract testing."""
from app.gateway.app import create_app
return create_app()
def _get_auth_client():
"""Get TestClient for auth API contract tests."""
return TestClient(_make_auth_app())
def test_api_auth_me_no_cookie_returns_structured_401():
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
_setup_config()
client = _get_auth_client()
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "not_authenticated"
assert "message" in body["detail"]
def test_api_auth_me_expired_token_returns_structured_401():
"""/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_expired"
def test_api_auth_me_invalid_sig_returns_structured_401():
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_invalid"
def test_api_login_bad_credentials_returns_structured_401():
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
)
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "invalid_credentials"
def test_api_login_success_no_token_in_body():
"""Successful login → response body has expires_in but NOT access_token."""
_setup_config()
client = _get_auth_client()
# Register first
client.post(
"/api/v1/auth/register",
json={"email": "contract-test@test.com", "password": "securepassword123"},
)
# Login
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "contract-test@test.com", "password": "securepassword123"},
)
assert resp.status_code == 200
body = resp.json()
assert "expires_in" in body
assert "access_token" not in body
# Token should be in cookie, not body
assert "access_token" in resp.cookies
def test_api_register_duplicate_returns_structured_400():
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
_setup_config()
client = _get_auth_client()
email = "dup-contract-test@test.com"
# First register
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
# Duplicate
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
assert resp.status_code == 400
body = resp.json()
assert body["detail"]["code"] == "email_already_exists"
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
def _unique_email(prefix: str) -> str:
return f"{prefix}-{secrets.token_hex(4)}@test.com"
def _get_set_cookie_headers(resp) -> list[str]:
"""Extract all set-cookie header values from a TestClient response."""
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
def test_register_http_cookie_httponly_true_secure_false():
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("http-cookie"), "password": "password123"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" not in cookie_header.lower().replace("samesite", "")
def test_register_https_cookie_httponly_true_secure_true():
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("https-cookie"), "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
assert "max-age" in cookie_header.lower()
def test_login_https_sets_secure_cookie():
"""HTTPS login → access_token cookie has secure flag."""
_setup_config()
client = _get_auth_client()
email = _unique_email("https-login")
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
resp = client.post(
"/api/v1/auth/login/local",
data={"username": email, "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 200
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
def test_csrf_cookie_secure_on_https():
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-https"), "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
csrf_header = csrf_cookies[0]
assert "secure" in csrf_header.lower()
assert "httponly" not in csrf_header.lower()
def test_csrf_cookie_not_secure_on_http():
"""HTTP register → csrf_token cookie does NOT have secure flag."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-http"), "password": "password123"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
csrf_header = csrf_cookies[0]
assert "secure" not in csrf_header.lower().replace("samesite", "")

View File

@ -7,12 +7,12 @@ import json
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
@ -1718,6 +1718,159 @@ class TestFeishuChannel:
_run(go())
class TestWeComChannel:
def test_publish_ws_inbound_starts_stream_and_publishes_message(self, monkeypatch):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = WeComChannel(bus, config={})
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
monkeypatch.setitem(
__import__("sys").modules,
"aibot",
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
)
frame = {
"body": {
"msgid": "msg-1",
"from": {"userid": "user-1"},
"aibotid": "bot-1",
"chattype": "single",
}
}
files = [{"type": "image", "url": "https://example.com/image.png"}]
await channel._publish_ws_inbound(frame, "hello", files=files)
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Working on it...", False)
bus.publish_inbound.assert_awaited_once()
inbound = bus.publish_inbound.await_args.args[0]
assert inbound.channel_name == "wecom"
assert inbound.chat_id == "user-1"
assert inbound.user_id == "user-1"
assert inbound.text == "hello"
assert inbound.thread_ts == "msg-1"
assert inbound.topic_id == "user-1"
assert inbound.files == files
assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"}
assert channel._ws_frames["msg-1"] is frame
assert channel._ws_stream_ids["msg-1"] == "stream-1"
_run(go())
def test_publish_ws_inbound_uses_configured_working_message(self, monkeypatch):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = WeComChannel(bus, config={"working_message": "Please wait..."})
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
channel._working_message = "Please wait..."
monkeypatch.setitem(
__import__("sys").modules,
"aibot",
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
)
frame = {
"body": {
"msgid": "msg-1",
"from": {"userid": "user-1"},
}
}
await channel._publish_ws_inbound(frame, "hello")
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Please wait...", False)
_run(go())
def test_on_outbound_sends_attachment_before_clearing_context(self, tmp_path):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
channel = WeComChannel(bus, config={})
frame = {"body": {"msgid": "msg-1"}}
ws_client = SimpleNamespace(
reply_stream=AsyncMock(),
reply=AsyncMock(),
)
channel._ws_client = ws_client
channel._ws_frames["msg-1"] = frame
channel._ws_stream_ids["msg-1"] = "stream-1"
channel._upload_media_ws = AsyncMock(return_value="media-1")
attachment_path = tmp_path / "image.png"
attachment_path.write_bytes(b"png")
attachment = ResolvedAttachment(
virtual_path="/mnt/user-data/outputs/image.png",
actual_path=attachment_path,
filename="image.png",
mime_type="image/png",
size=attachment_path.stat().st_size,
is_image=True,
)
msg = OutboundMessage(
channel_name="wecom",
chat_id="user-1",
thread_id="thread-1",
text="done",
attachments=[attachment],
is_final=True,
thread_ts="msg-1",
)
await channel._on_outbound(msg)
ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "done", True)
channel._upload_media_ws.assert_awaited_once_with(
media_type="image",
filename="image.png",
path=str(attachment_path),
size=attachment.size,
)
ws_client.reply.assert_awaited_once_with(frame, {"image": {"media_id": "media-1"}, "msgtype": "image"})
assert "msg-1" not in channel._ws_frames
assert "msg-1" not in channel._ws_stream_ids
_run(go())
def test_send_falls_back_to_send_message_without_thread_context(self):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
channel = WeComChannel(bus, config={})
channel._ws_client = SimpleNamespace(send_message=AsyncMock())
msg = OutboundMessage(
channel_name="wecom",
chat_id="user-1",
thread_id="thread-1",
text="hello",
thread_ts=None,
)
await channel.send(msg)
channel._ws_client.send_message.assert_awaited_once_with(
"user-1",
{"msgtype": "markdown", "markdown": {"content": "hello"}},
)
_run(go())
class TestChannelService:
def test_get_status_no_channels(self):
from app.channels.service import ChannelService
@ -1835,6 +1988,47 @@ class TestSlackSendRetry:
_run(go())
class TestSlackAllowedUsers:
def test_numeric_allowed_users_match_string_event_user_id(self):
from app.channels.slack import SlackChannel
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = SlackChannel(
bus=bus,
config={"allowed_users": [123456]},
)
channel._loop = MagicMock()
channel._loop.is_running.return_value = True
channel._add_reaction = MagicMock()
channel._send_running_reply = MagicMock()
event = {
"user": "123456",
"text": "hello from slack",
"channel": "C123",
"ts": "1710000000.000100",
}
def submit_coro(coro, loop):
coro.close()
return MagicMock()
with patch(
"app.channels.slack.asyncio.run_coroutine_threadsafe",
side_effect=submit_coro,
) as submit:
channel._handle_message_event(event)
channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
submit.assert_called_once()
inbound = bus.publish_inbound.call_args.args[0]
assert inbound.user_id == "123456"
assert inbound.chat_id == "C123"
assert inbound.text == "hello from slack"
def test_raises_after_all_retries_exhausted(self):
from app.channels.slack import SlackChannel
@ -1854,6 +2048,20 @@ class TestSlackSendRetry:
_run(go())
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.slack import SlackChannel
async def go():
bus = MessageBus()
ch = SlackChannel(bus=bus, config={"bot_token": "xoxb-test", "app_token": "xapp-test"})
ch._web_client = MagicMock()
msg = OutboundMessage(channel_name="slack", chat_id="C123", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
# ---------------------------------------------------------------------------
# Telegram send retry tests
@ -1912,6 +2120,36 @@ class TestTelegramSendRetry:
_run(go())
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.telegram import TelegramChannel
async def go():
bus = MessageBus()
ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"})
ch._application = MagicMock()
msg = OutboundMessage(channel_name="telegram", chat_id="12345", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
class TestFeishuSendRetry:
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.feishu import FeishuChannel
async def go():
bus = MessageBus()
ch = FeishuChannel(bus=bus, config={"app_id": "id", "app_secret": "secret"})
ch._api_client = MagicMock()
msg = OutboundMessage(channel_name="feishu", chat_id="chat", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
# ---------------------------------------------------------------------------
# Telegram private-chat thread context tests

View File

@ -59,18 +59,20 @@ class TestClientInit:
assert client._subagent_enabled is False
assert client._plan_mode is False
assert client._agent_name is None
assert client._available_skills is None
assert client._checkpointer is None
assert client._agent is None
def test_custom_params(self, mock_app_config):
mock_middleware = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", middlewares=[mock_middleware])
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
assert c._model_name == "gpt-4"
assert c._thinking_enabled is False
assert c._subagent_enabled is True
assert c._plan_mode is True
assert c._agent_name == "test-agent"
assert c._available_skills == {"skill1", "skill2"}
assert c._middlewares == [mock_middleware]
def test_invalid_agent_name(self, mock_app_config):
@ -394,8 +396,10 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._agent_name = "custom-agent"
client._available_skills = {"test_skill"}
client._ensure_agent(config)
assert client._agent is mock_agent
@ -404,6 +408,7 @@ class TestEnsureAgent:
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
mock_apply_prompt.assert_called_once()
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock()
@ -441,6 +446,7 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
@ -469,7 +475,7 @@ class TestEnsureAgent:
"""_ensure_agent does not recreate if config key unchanged."""
mock_agent = MagicMock()
client._agent = mock_agent
client._agent_config_key = (None, True, False, False)
client._agent_config_key = (None, True, False, False, None, None)
config = client._get_runnable_config("t1")
client._ensure_agent(config)
@ -1276,6 +1282,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config_a)
first_agent = client._agent
@ -1303,6 +1310,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client._ensure_agent(config)
@ -1327,6 +1335,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client.reset_agent()

View File

@ -439,6 +439,15 @@ class TestAgentsAPI:
assert "agent-one" in names
assert "agent-two" in names
def test_list_agents_includes_soul(self, agent_client):
agent_client.post("/api/agents", json={"name": "soul-agent", "soul": "My soul content"})
response = agent_client.get("/api/agents")
assert response.status_code == 200
agents = response.json()["agents"]
soul_agent = next(a for a in agents if a["name"] == "soul-agent")
assert soul_agent["soul"] == "My soul content"
def test_get_agent(self, agent_client):
agent_client.post("/api/agents", json={"name": "test-agent", "soul": "Hello world"})

View File

@ -0,0 +1,214 @@
"""Tests for _ensure_admin_user() in app.py.
Covers: first-boot admin creation, auto-reset on needs_setup=True,
no-op on needs_setup=False, migration, and edge cases.
"""
import asyncio
import os
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.models import User
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
yield
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
def _make_app_stub(store=None):
"""Minimal app-like object with state.store."""
app = SimpleNamespace()
app.state = SimpleNamespace()
app.state.store = store
return app
def _make_provider(user_count=0, admin_user=None):
p = AsyncMock()
p.count_users = AsyncMock(return_value=user_count)
p.create_user = AsyncMock(
side_effect=lambda **kw: User(
email=kw["email"],
password_hash="hashed",
system_role=kw.get("system_role", "user"),
needs_setup=kw.get("needs_setup", False),
)
)
p.get_user_by_email = AsyncMock(return_value=admin_user)
p.update_user = AsyncMock(side_effect=lambda u: u)
return p
# ── First boot: no users ─────────────────────────────────────────────────
def test_first_boot_creates_admin():
"""count_users==0 → create admin with needs_setup=True."""
provider = _make_provider(user_count=0)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
call_kwargs = provider.create_user.call_args[1]
assert call_kwargs["email"] == "admin@deerflow.dev"
assert call_kwargs["system_role"] == "admin"
assert call_kwargs["needs_setup"] is True
assert len(call_kwargs["password"]) > 10 # random password generated
def test_first_boot_triggers_migration_if_store_present():
"""First boot with store → _migrate_orphaned_threads called."""
provider = _make_provider(user_count=0)
store = AsyncMock()
store.asearch = AsyncMock(return_value=[])
app = _make_app_stub(store=store)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
store.asearch.assert_called_once()
def test_first_boot_no_store_skips_migration():
"""First boot without store → no crash, migration skipped."""
provider = _make_provider(user_count=0)
app = _make_app_stub(store=None)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
def test_needs_setup_true_resets_password():
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
admin = User(
email="admin@deerflow.dev",
password_hash="old-hash",
system_role="admin",
needs_setup=True,
token_version=0,
created_at=datetime.now(UTC) - timedelta(seconds=30),
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
# Password was reset
provider.update_user.assert_called_once()
updated = provider.update_user.call_args[0][0]
assert updated.password_hash == "new-hash"
assert updated.token_version == 1
def test_needs_setup_true_consecutive_resets_increment_version():
"""Two boots with needs_setup=True → token_version increments each time."""
admin = User(
email="admin@deerflow.dev",
password_hash="hash",
system_role="admin",
needs_setup=True,
token_version=3,
created_at=datetime.now(UTC) - timedelta(seconds=30),
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
updated = provider.update_user.call_args[0][0]
assert updated.token_version == 4
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
def test_needs_setup_false_no_reset():
"""Admin with needs_setup=False → no password reset, no update."""
admin = User(
email="admin@deerflow.dev",
password_hash="stable-hash",
system_role="admin",
needs_setup=False,
token_version=2,
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.update_user.assert_not_called()
assert admin.password_hash == "stable-hash"
assert admin.token_version == 2
# ── Edge cases ───────────────────────────────────────────────────────────
def test_no_admin_email_found_no_crash():
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
provider = _make_provider(user_count=3, admin_user=None)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.update_user.assert_not_called()
provider.create_user.assert_not_called()
def test_migration_failure_is_non_fatal():
"""_migrate_orphaned_threads exception is caught and logged."""
provider = _make_provider(user_count=0)
store = AsyncMock()
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
app = _make_app_stub(store=store)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
# Should not raise
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
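# Condensed sketch of the boot flow under test (hypothetical shape; the
# orphan-thread migration step is elided): first boot creates the admin with
# a random password and needs_setup=True; a later boot that still sees
# needs_setup=True re-randomizes the password and bumps token_version so any
# leaked session dies; a fully configured admin is left untouched.
async def _sketch_ensure_admin(provider, hash_password_async):
    import secrets
    if await provider.count_users() == 0:
        await provider.create_user(
            email="admin@deerflow.dev",
            password=secrets.token_urlsafe(16),
            system_role="admin",
            needs_setup=True,
        )
        return
    admin = await provider.get_user_by_email("admin@deerflow.dev")
    if admin is not None and admin.needs_setup:
        admin.password_hash = await hash_password_async(secrets.token_urlsafe(16))
        admin.token_version += 1
        await provider.update_user(admin)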

View File

@ -0,0 +1,459 @@
"""Tests for file_conversion utilities (PR1: pymupdf4llm + asyncio.to_thread; PR2: extract_outline)."""
from __future__ import annotations
import asyncio
import sys
from types import ModuleType
from unittest.mock import MagicMock, patch
from deerflow.utils.file_conversion import (
_ASYNC_THRESHOLD_BYTES,
_MIN_CHARS_PER_PAGE,
MAX_OUTLINE_ENTRIES,
_do_convert,
_pymupdf_output_too_sparse,
convert_file_to_markdown,
extract_outline,
)
def _make_pymupdf_mock(page_count: int) -> ModuleType:
"""Return a fake *pymupdf* module whose ``open()`` reports *page_count* pages."""
mock_doc = MagicMock()
mock_doc.__len__ = MagicMock(return_value=page_count)
fake_pymupdf = ModuleType("pymupdf")
fake_pymupdf.open = MagicMock(return_value=mock_doc) # type: ignore[attr-defined]
return fake_pymupdf
def _run(coro):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
# ---------------------------------------------------------------------------
# _pymupdf_output_too_sparse
# ---------------------------------------------------------------------------
class TestPymupdfOutputTooSparse:
"""Check the chars-per-page sparsity heuristic."""
def test_dense_text_pdf_not_sparse(self, tmp_path):
"""Normal text PDF: many chars per page → not sparse."""
pdf = tmp_path / "dense.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
        # 10 000 chars across 10 pages → 1000 chars/page ≫ threshold
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=10)}):
result = _pymupdf_output_too_sparse("x" * 10_000, pdf)
assert result is False
def test_image_based_pdf_is_sparse(self, tmp_path):
"""Image-based PDF: near-zero chars per page → sparse."""
pdf = tmp_path / "image.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 612 chars / 31 pages ≈ 19.7/page < _MIN_CHARS_PER_PAGE (50)
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=31)}):
result = _pymupdf_output_too_sparse("x" * 612, pdf)
assert result is True
def test_fallback_when_pymupdf_unavailable(self, tmp_path):
"""When pymupdf is not installed, fall back to absolute 200-char threshold."""
pdf = tmp_path / "broken.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# Remove pymupdf from sys.modules so the `import pymupdf` inside the
# function raises ImportError, triggering the absolute-threshold fallback.
with patch.dict(sys.modules, {"pymupdf": None}):
sparse = _pymupdf_output_too_sparse("x" * 100, pdf)
not_sparse = _pymupdf_output_too_sparse("x" * 300, pdf)
assert sparse is True
assert not_sparse is False
def test_exactly_at_threshold_is_not_sparse(self, tmp_path):
"""Chars-per-page == threshold is treated as NOT sparse (boundary inclusive)."""
pdf = tmp_path / "boundary.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 2 pages × _MIN_CHARS_PER_PAGE chars = exactly at threshold
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=2)}):
result = _pymupdf_output_too_sparse("x" * (_MIN_CHARS_PER_PAGE * 2), pdf)
assert result is False
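# A sketch of the heuristic these tests pin down — assumed shape only; the
# real implementation lives in deerflow.utils.file_conversion:
#
#     def _pymupdf_output_too_sparse(text: str, pdf_path: Path) -> bool:
#         try:
#             import pymupdf
#             page_count = max(len(pymupdf.open(pdf_path)), 1)
#         except ImportError:
#             return len(text) < 200  # absolute fallback, no page count available
#         # Strict <, so exactly _MIN_CHARS_PER_PAGE chars/page is NOT sparse.
#         return len(text) / page_count < _MIN_CHARS_PER_PAGE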
# ---------------------------------------------------------------------------
# _do_convert — routing logic
# ---------------------------------------------------------------------------
class TestDoConvert:
"""Verify that _do_convert routes to the right sub-converter."""
def test_non_pdf_always_uses_markitdown(self, tmp_path):
"""DOCX / XLSX / PPTX always go through MarkItDown regardless of setting."""
docx = tmp_path / "report.docx"
docx.write_bytes(b"PK fake docx")
with patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="# Markdown from MarkItDown",
) as mock_md:
result = _do_convert(docx, "auto")
mock_md.assert_called_once_with(docx)
assert result == "# Markdown from MarkItDown"
def test_pdf_auto_uses_pymupdf4llm_when_dense(self, tmp_path):
"""auto mode: use pymupdf4llm output when it's dense enough."""
pdf = tmp_path / "report.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
dense_text = "# Heading\n" + "word " * 2000 # clearly dense
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=dense_text,
),
patch(
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
return_value=False,
),
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_not_called()
assert result == dense_text
def test_pdf_auto_falls_back_when_sparse(self, tmp_path):
"""auto mode: fall back to MarkItDown when pymupdf4llm output is sparse."""
pdf = tmp_path / "scanned.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value="x" * 612, # 19.7 chars/page for 31-page doc
),
patch(
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
return_value=True,
),
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="OCR result via MarkItDown",
) as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_called_once_with(pdf)
assert result == "OCR result via MarkItDown"
def test_pdf_explicit_pymupdf4llm_skips_sparsity_check(self, tmp_path):
"""'pymupdf4llm' mode: use output as-is even if sparse."""
pdf = tmp_path / "explicit.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
sparse_text = "x" * 10 # very short
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=sparse_text,
),
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
):
result = _do_convert(pdf, "pymupdf4llm")
mock_md.assert_not_called()
assert result == sparse_text
def test_pdf_explicit_markitdown_skips_pymupdf4llm(self, tmp_path):
"""'markitdown' mode: never attempt pymupdf4llm."""
pdf = tmp_path / "force_md.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch("deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm") as mock_pymu,
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="MarkItDown result",
),
):
result = _do_convert(pdf, "markitdown")
mock_pymu.assert_not_called()
assert result == "MarkItDown result"
def test_pdf_auto_falls_back_when_pymupdf4llm_not_installed(self, tmp_path):
"""auto mode: if pymupdf4llm is not installed, use MarkItDown directly."""
pdf = tmp_path / "no_pymupdf.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=None, # None signals not installed
),
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="MarkItDown fallback",
) as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_called_once_with(pdf)
assert result == "MarkItDown fallback"
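# Routing implied by the five tests above (a sketch, not the actual code):
#
#     def _do_convert(path: Path, pdf_converter: str) -> str:
#         if path.suffix.lower() != ".pdf" or pdf_converter == "markitdown":
#             return _convert_with_markitdown(path)
#         text = _convert_pdf_with_pymupdf4llm(path)  # None when not installed
#         if text is None:
#             return _convert_with_markitdown(path)
#         if pdf_converter == "auto" and _pymupdf_output_too_sparse(text, path):
#             return _convert_with_markitdown(path)  # likely scanned → OCR path
#         return text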
# ---------------------------------------------------------------------------
# convert_file_to_markdown — async + file writing
# ---------------------------------------------------------------------------
class TestConvertFileToMarkdown:
def test_small_file_runs_synchronously(self, tmp_path):
"""Small files (< 1 MB) are converted in the event loop thread."""
pdf = tmp_path / "small.pdf"
pdf.write_bytes(b"%PDF-1.4 " + b"x" * 100) # well under 1 MB
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value="# Small PDF",
) as mock_convert,
patch("asyncio.to_thread") as mock_thread,
):
md_path = _run(convert_file_to_markdown(pdf))
# asyncio.to_thread must NOT have been called
mock_thread.assert_not_called()
mock_convert.assert_called_once()
assert md_path == pdf.with_suffix(".md")
assert md_path.read_text() == "# Small PDF"
def test_large_file_offloaded_to_thread(self, tmp_path):
"""Large files (> 1 MB) are offloaded via asyncio.to_thread."""
pdf = tmp_path / "large.pdf"
# Write slightly more than the threshold
pdf.write_bytes(b"%PDF-1.4 " + b"x" * (_ASYNC_THRESHOLD_BYTES + 1))
async def fake_to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value="# Large PDF",
),
patch("asyncio.to_thread", side_effect=fake_to_thread) as mock_thread,
):
md_path = _run(convert_file_to_markdown(pdf))
mock_thread.assert_called_once()
assert md_path == pdf.with_suffix(".md")
assert md_path.read_text() == "# Large PDF"
def test_returns_none_on_conversion_error(self, tmp_path):
"""If conversion raises, return None without propagating the exception."""
pdf = tmp_path / "broken.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
side_effect=RuntimeError("conversion failed"),
),
):
result = _run(convert_file_to_markdown(pdf))
assert result is None
def test_writes_utf8_markdown_file(self, tmp_path):
"""Generated .md file is written with UTF-8 encoding."""
pdf = tmp_path / "report.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
chinese_content = "# 中文报告\n\n这是测试内容。"
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value=chinese_content,
),
):
md_path = _run(convert_file_to_markdown(pdf))
assert md_path is not None
assert md_path.read_text(encoding="utf-8") == chinese_content
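# The offload rule the tests above encode, sketched under the same assumptions:
#
#     async def convert_file_to_markdown(path: Path) -> Path | None:
#         try:
#             converter = _get_pdf_converter()
#             if path.stat().st_size > _ASYNC_THRESHOLD_BYTES:
#                 # Big input: run conversion off the event loop thread.
#                 markdown = await asyncio.to_thread(_do_convert, path, converter)
#             else:
#                 markdown = _do_convert(path, converter)
#             md_path = path.with_suffix(".md")
#             md_path.write_text(markdown, encoding="utf-8")
#             return md_path
#         except Exception:
#             return None  # conversion failures are reported as None, never raised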
# ---------------------------------------------------------------------------
# extract_outline
# ---------------------------------------------------------------------------
class TestExtractOutline:
"""Tests for extract_outline()."""
def test_empty_file_returns_empty(self, tmp_path):
"""Empty markdown file yields no outline entries."""
md = tmp_path / "empty.md"
md.write_text("", encoding="utf-8")
assert extract_outline(md) == []
def test_missing_file_returns_empty(self, tmp_path):
"""Non-existent path returns [] without raising."""
assert extract_outline(tmp_path / "nonexistent.md") == []
def test_standard_markdown_headings(self, tmp_path):
"""# / ## / ### headings are all recognised."""
md = tmp_path / "doc.md"
md.write_text(
"# Chapter One\n\nSome text.\n\n## Section 1.1\n\nMore text.\n\n### Sub 1.1.1\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 3
assert outline[0] == {"title": "Chapter One", "line": 1}
assert outline[1] == {"title": "Section 1.1", "line": 5}
assert outline[2] == {"title": "Sub 1.1.1", "line": 9}
def test_bold_sec_item_heading(self, tmp_path):
"""**ITEM N. TITLE** lines in SEC filings are recognised."""
md = tmp_path / "10k.md"
md.write_text(
"Cover page text.\n\n**ITEM 1. BUSINESS**\n\nBody.\n\n**ITEM 1A. RISK FACTORS**\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 2
assert outline[0] == {"title": "ITEM 1. BUSINESS", "line": 3}
assert outline[1] == {"title": "ITEM 1A. RISK FACTORS", "line": 7}
def test_bold_part_heading(self, tmp_path):
"""**PART I** / **PART II** headings are recognised."""
md = tmp_path / "10k.md"
md.write_text("**PART I**\n\n**PART II**\n\n**PART III**\n", encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 3
titles = [e["title"] for e in outline]
assert "PART I" in titles
assert "PART II" in titles
assert "PART III" in titles
def test_sec_cover_page_boilerplate_excluded(self, tmp_path):
"""Address lines and short cover boilerplate must NOT appear in outline."""
md = tmp_path / "8k.md"
md.write_text(
"## **UNITED STATES SECURITIES AND EXCHANGE COMMISSION**\n\n**WASHINGTON, DC 20549**\n\n**CURRENT REPORT**\n\n**SIGNATURES**\n\n**TESLA, INC.**\n\n**ITEM 2.02. RESULTS OF OPERATIONS**\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
# Cover-page boilerplate should be excluded
assert "WASHINGTON, DC 20549" not in titles
assert "CURRENT REPORT" not in titles
assert "SIGNATURES" not in titles
assert "TESLA, INC." not in titles
# Real SEC heading must be included
assert "ITEM 2.02. RESULTS OF OPERATIONS" in titles
def test_chinese_headings_via_standard_markdown(self, tmp_path):
"""Chinese annual report headings emitted as # by pymupdf4llm are captured."""
md = tmp_path / "annual.md"
md.write_text(
"# 第一节 公司简介\n\n内容。\n\n## 第三节 管理层讨论与分析\n\n分析内容。\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 2
assert outline[0]["title"] == "第一节 公司简介"
assert outline[1]["title"] == "第三节 管理层讨论与分析"
def test_outline_capped_at_max_entries(self, tmp_path):
"""When truncated, result has MAX_OUTLINE_ENTRIES real entries + 1 sentinel."""
lines = [f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 10)]
md = tmp_path / "long.md"
md.write_text("\n".join(lines), encoding="utf-8")
outline = extract_outline(md)
# Last entry is the truncation sentinel
assert outline[-1] == {"truncated": True}
# Visible entries are exactly MAX_OUTLINE_ENTRIES
visible = [e for e in outline if not e.get("truncated")]
assert len(visible) == MAX_OUTLINE_ENTRIES
def test_no_truncation_sentinel_when_under_limit(self, tmp_path):
"""Short documents produce no sentinel entry."""
lines = [f"# Heading {i}" for i in range(5)]
md = tmp_path / "short.md"
md.write_text("\n".join(lines), encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 5
assert not any(e.get("truncated") for e in outline)
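# The truncation contract pinned by the two tests above, as a sketch:
#
#     if len(entries) > MAX_OUTLINE_ENTRIES:
#         entries = entries[:MAX_OUTLINE_ENTRIES] + [{"truncated": True}]
#
# i.e. the sentinel is an extra element, not counted against the cap.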
def test_blank_lines_and_whitespace_ignored(self, tmp_path):
"""Blank lines between headings do not produce empty entries."""
md = tmp_path / "spaced.md"
md.write_text("\n\n# Title One\n\n\n\n# Title Two\n\n", encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 2
assert all(e["title"] for e in outline)
def test_inline_bold_not_confused_with_heading(self, tmp_path):
"""Mid-sentence bold text must not be mistaken for a heading."""
md = tmp_path / "prose.md"
md.write_text(
"This sentence has **bold words** inside it.\n\nAnother with **MULTIPLE CAPS** inline.\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert outline == []
def test_split_bold_heading_academic_paper(self, tmp_path):
"""**<num>** **<title>** lines from academic papers are recognised (Style 3)."""
md = tmp_path / "paper.md"
md.write_text(
"## **Attention Is All You Need**\n\n**1** **Introduction**\n\nBody text.\n\n**2** **Background**\n\nMore text.\n\n**3.1** **Encoder and Decoder Stacks**\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
assert "1 Introduction" in titles
assert "2 Background" in titles
assert "3.1 Encoder and Decoder Stacks" in titles
def test_split_bold_year_columns_excluded(self, tmp_path):
"""Financial table headers like **2023** **2022** **2021** are NOT headings."""
md = tmp_path / "annual.md"
md.write_text(
"# Financial Summary\n\n**2023** **2022** **2021**\n\nRevenue 100 90 80\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
# Only the # heading should appear, not the year-column row
assert titles == ["Financial Summary"]
def test_adjacent_bold_spans_merged_in_markdown_heading(self, tmp_path):
"""** ** artefacts inside a # heading are merged into clean plain text."""
md = tmp_path / "sec.md"
md.write_text(
"## **UNITED STATES** **SECURITIES AND EXCHANGE COMMISSION**\n\nBody text.\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 1
# Title must be clean — no ** ** artefacts
assert outline[0]["title"] == "UNITED STATES SECURITIES AND EXCHANGE COMMISSION"

View File

@ -8,6 +8,7 @@ import pytest
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
from deerflow.tools.builtins.invoke_acp_agent_tool import (
_build_acp_mcp_servers,
_build_mcp_servers,
_build_permission_response,
_get_work_dir,
@ -42,6 +43,43 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_acp_mcp_servers_formats_list_payload():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
"http": McpServerConfig(enabled=True, type="http", url="https://example.com/mcp", headers={"Authorization": "Bearer token"}),
"disabled": McpServerConfig(enabled=False, type="stdio", command="echo"),
},
skills={},
)
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: fresh_config),
)
try:
assert _build_acp_mcp_servers() == [
{
"name": "stdio",
"type": "stdio",
"command": "npx",
"args": ["srv"],
"env": [{"name": "FOO", "value": "bar"}],
},
{
"name": "http",
"type": "http",
"url": "https://example.com/mcp",
"headers": [{"name": "Authorization", "value": "Bearer token"}],
},
]
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_permission_response_prefers_allow_once():
response = _build_permission_response(
[
@ -251,9 +289,15 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
assert captured["spawn"] == {"cmd": "codex-acp", "args": ["--json"], "cwd": expected_cwd}
assert captured["new_session"] == {
"cwd": expected_cwd,
"mcp_servers": {
"github": {"transport": "stdio", "command": "npx", "args": ["github-mcp"]},
},
"mcp_servers": [
{
"name": "github",
"type": "stdio",
"command": "npx",
"args": ["github-mcp"],
"env": [],
}
],
"model": "gpt-5-codex",
}
assert captured["prompt"] == {
@ -448,6 +492,94 @@ async def test_invoke_acp_agent_passes_env_to_spawn(monkeypatch, tmp_path):
assert captured["env"] == {"OPENAI_API_KEY": "sk-from-env", "FOO": "bar"}
@pytest.mark.anyio
async def test_invoke_acp_agent_skips_invalid_mcp_servers(monkeypatch, tmp_path, caplog):
"""Invalid MCP config should be logged and skipped instead of failing ACP invocation."""
from deerflow.config import paths as paths_module
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
monkeypatch.setattr(
"deerflow.tools.builtins.invoke_acp_agent_tool._build_acp_mcp_servers",
lambda: (_ for _ in ()).throw(ValueError("missing command")),
)
captured: dict[str, object] = {}
class DummyClient:
def __init__(self) -> None:
self._chunks: list[str] = []
@property
def collected_text(self) -> str:
return ""
async def session_update(self, session_id, update, **kwargs):
pass
async def request_permission(self, options, session_id, tool_call, **kwargs):
raise AssertionError("should not be called")
class DummyConn:
async def initialize(self, **kwargs):
pass
async def new_session(self, **kwargs):
captured["new_session"] = kwargs
return SimpleNamespace(session_id="s1")
async def prompt(self, **kwargs):
pass
class DummyProcessContext:
def __init__(self, client, cmd, *args, env=None, cwd=None):
captured["spawn"] = {"cmd": cmd, "args": list(args), "env": env, "cwd": cwd}
async def __aenter__(self):
return DummyConn(), object()
async def __aexit__(self, exc_type, exc, tb):
return False
class DummyRequestError(Exception):
@staticmethod
def method_not_found(method):
return DummyRequestError(method)
monkeypatch.setitem(
sys.modules,
"acp",
SimpleNamespace(
PROTOCOL_VERSION="2026-03-24",
Client=DummyClient,
RequestError=DummyRequestError,
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
text_block=lambda text: {"type": "text", "text": text},
),
)
monkeypatch.setitem(
sys.modules,
"acp.schema",
SimpleNamespace(
ClientCapabilities=lambda: {},
Implementation=lambda **kwargs: kwargs,
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
),
)
tool = build_invoke_acp_agent_tool({"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")})
caplog.set_level("WARNING")
try:
await tool.coroutine(agent="codex", prompt="Do something")
finally:
sys.modules.pop("acp", None)
sys.modules.pop("acp.schema", None)
assert captured["new_session"]["mcp_servers"] == []
assert "continuing without MCP servers" in caplog.text
assert "missing command" in caplog.text
@pytest.mark.anyio
async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch, tmp_path):
"""When env is empty, None is passed to spawn_agent_process (subprocess inherits parent env)."""

View File

@ -0,0 +1,312 @@
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
Validates that the LangGraph auth layer enforces the same rules as Gateway:
cookie → JWT decode → DB lookup → token_version check → owner filter
"""
import asyncio
import os
from datetime import timedelta
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
from langgraph_sdk import Auth
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.jwt import create_access_token, decode_token
from app.gateway.auth.models import User
from app.gateway.langgraph_auth import add_owner_filter, authenticate
# ── Helpers ───────────────────────────────────────────────────────────────
_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
yield
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
def _req(cookies=None, method="GET", headers=None):
return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})
def _user(user_id=None, token_version=0):
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
def _mock_provider(user=None):
p = AsyncMock()
p.get_user = AsyncMock(return_value=user)
return p
# ── @auth.authenticate ───────────────────────────────────────────────────
def test_no_cookie_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req()))
assert exc.value.status_code == 401
assert "Not authenticated" in str(exc.value.detail)
def test_invalid_jwt_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": "garbage"})))
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
def test_expired_jwt_raises_401():
token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
def test_user_not_found_raises_401():
token = create_access_token("ghost")
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
assert "User not found" in str(exc.value.detail)
def test_token_version_mismatch_raises_401():
user = _user(token_version=2)
token = create_access_token(str(user.id), token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
assert "revoked" in str(exc.value.detail).lower()
def test_valid_token_returns_user_id():
user = _user(token_version=0)
token = create_access_token(str(user.id), token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
assert result == str(user.id)
def test_valid_token_matching_version():
user = _user(token_version=5)
token = create_access_token(str(user.id), token_version=5)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
assert result == str(user.id)
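# Enforcement order these tests establish for authenticate() (summary, not code):
#   1. POST/PUT/DELETE only: the csrf_token cookie must equal the x-csrf-token
#      header, else 403 — checked before any JWT work (see CSRF tests below).
#   2. The access_token cookie must be present, else 401 "Not authenticated".
#   3. The JWT must decode under the shared secret (missing "ver" defaults to 0),
#      else 401 "Token error".
#   4. The user must exist in the DB, else 401 "User not found".
#   5. payload.ver must match user.token_version, else 401 (token revoked).
#   6. On success, str(user.id) is returned as the LangGraph identity.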
# ── @auth.authenticate edge cases ────────────────────────────────────────
def test_provider_exception_propagates():
"""Provider raises → should not be swallowed silently."""
token = create_access_token("user-1")
p = AsyncMock()
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
with pytest.raises(RuntimeError, match="DB down"):
asyncio.run(authenticate(_req({"access_token": token})))
def test_jwt_missing_ver_defaults_to_zero():
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": raw})))
assert result == uid
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": raw})))
assert exc.value.status_code == 401
def test_wrong_secret_raises_401():
"""Token signed with different secret → 401."""
import jwt as pyjwt
raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": raw})))
assert exc.value.status_code == 401
# ── @auth.on (owner filter) ──────────────────────────────────────────────
class _FakeUser:
"""Minimal BaseUser-compatible object without langgraph_api.config dependency."""
def __init__(self, identity: str):
self.identity = identity
self.is_authenticated = True
self.display_name = identity
def _make_ctx(user_id):
return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])
def test_filter_injects_user_id():
value = {}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["user_id"] == "user-a"
def test_filter_preserves_existing_metadata():
value = {"metadata": {"title": "hello"}}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["user_id"] == "user-a"
assert value["metadata"]["title"] == "hello"
def test_filter_returns_user_id_dict():
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
assert result == {"user_id": "user-x"}
def test_filter_read_write_consistency():
value = {}
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
assert value["metadata"]["user_id"] == filter_dict["user_id"]
def test_different_users_different_filters():
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
assert f_a["user_id"] != f_b["user_id"]
def test_filter_overrides_conflicting_user_id():
"""If value already has a different user_id in metadata, it gets overwritten."""
value = {"metadata": {"user_id": "attacker"}}
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
assert value["metadata"]["user_id"] == "real-owner"
def test_filter_with_empty_metadata():
"""Explicit empty metadata dict is fine."""
value = {"metadata": {}}
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
assert value["metadata"]["user_id"] == "user-z"
assert result == {"user_id": "user-z"}
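# Behaviour pinned above, sketched (the real handler is add_owner_filter in
# app.gateway.langgraph_auth):
#
#     async def add_owner_filter(ctx, value):
#         value.setdefault("metadata", {})["user_id"] = ctx.user.identity
#         return {"user_id": ctx.user.identity}
#
# Stamping value["metadata"] covers writes; the returned dict is the read
# filter, which is why the read/write consistency test compares the two.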
# ── Gateway parity ───────────────────────────────────────────────────────
def test_shared_jwt_secret():
token = create_access_token("user-1", token_version=3)
payload = decode_token(token)
from app.gateway.auth.errors import TokenError
assert not isinstance(payload, TokenError)
assert payload.sub == "user-1"
assert payload.ver == 3
def test_langgraph_json_has_auth_path():
import json
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
assert "auth" in config
assert "langgraph_auth" in config["auth"]["path"]
def test_auth_handler_has_both_layers():
from app.gateway.langgraph_auth import auth
assert auth._authenticate_handler is not None
assert len(auth._global_handlers) == 1
# ── CSRF in LangGraph auth ──────────────────────────────────────────────
def test_csrf_get_no_check():
"""GET requests skip CSRF — should proceed to JWT validation."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="GET")))
# Rejected by missing cookie, NOT by CSRF
assert exc.value.status_code == 401
assert "Not authenticated" in str(exc.value.detail)
def test_csrf_post_missing_token():
"""POST without CSRF token → 403."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
assert exc.value.status_code == 403
assert "CSRF token missing" in str(exc.value.detail)
def test_csrf_post_mismatched_token():
"""POST with mismatched CSRF tokens → 403."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(
authenticate(
_req(
method="POST",
cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
headers={"x-csrf-token": "wrong-token"},
)
)
)
assert exc.value.status_code == 403
assert "mismatch" in str(exc.value.detail)
def test_csrf_post_matching_token_proceeds_to_jwt():
"""POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(
authenticate(
_req(
method="POST",
cookies={"access_token": "garbage", "csrf_token": "same-token"},
headers={"x-csrf-token": "same-token"},
)
)
)
# Past CSRF, rejected by JWT decode
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
def test_csrf_put_requires_token():
"""PUT also requires CSRF."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
assert exc.value.status_code == 403
def test_csrf_delete_requires_token():
"""DELETE also requires CSRF."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
assert exc.value.status_code == 403

View File

@ -0,0 +1,388 @@
import errno
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
class TestPathMapping:
def test_path_mapping_dataclass(self):
mapping = PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True)
assert mapping.container_path == "/mnt/skills"
assert mapping.local_path == "/home/user/skills"
assert mapping.read_only is True
def test_path_mapping_defaults_to_false(self):
mapping = PathMapping(container_path="/mnt/data", local_path="/home/user/data")
assert mapping.read_only is False
class TestLocalSandboxPathResolution:
def test_resolve_path_exact_match(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/skills")
assert resolved == "/home/user/skills"
def test_resolve_path_nested_path(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/skills/agent/prompt.py")
assert resolved == "/home/user/skills/agent/prompt.py"
def test_resolve_path_no_mapping(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/other/file.txt")
assert resolved == "/mnt/other/file.txt"
def test_resolve_path_longest_prefix_first(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
PathMapping(container_path="/mnt", local_path="/var/mnt"),
],
)
resolved = sandbox._resolve_path("/mnt/skills/file.py")
# Should match /mnt/skills first (longer prefix)
assert resolved == "/home/user/skills/file.py"
def test_reverse_resolve_path_exact_match(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(skills_dir))
assert resolved == "/mnt/skills"
def test_reverse_resolve_path_nested(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
file_path = skills_dir / "agent" / "prompt.py"
file_path.parent.mkdir()
file_path.write_text("test")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(file_path))
assert resolved == "/mnt/skills/agent/prompt.py"
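# The longest-prefix-first rule these tests rely on, as a sketch (attribute
# names here are assumptions):
#
#     def _resolve_path(self, path: str) -> str:
#         mappings = sorted(self._mappings, key=lambda m: len(m.container_path), reverse=True)
#         for m in mappings:
#             if path == m.container_path or path.startswith(m.container_path + "/"):
#                 return m.local_path + path[len(m.container_path):]
#         return path  # unmapped paths pass through unchanged
#
# The '+ "/"' guard is what keeps /mnt/foo from matching /mnt/foobar in the
# partial-prefix reverse-resolution test further down.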
class TestReadOnlyPath:
def test_is_read_only_true(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
assert sandbox._is_read_only_path("/home/user/skills/file.py") is True
def test_is_read_only_false_for_writable(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path="/home/user/data", read_only=False),
],
)
assert sandbox._is_read_only_path("/home/user/data/file.txt") is False
def test_is_read_only_false_for_unmapped_path(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
# Path not under any mapping
assert sandbox._is_read_only_path("/tmp/other/file.txt") is False
def test_is_read_only_true_for_exact_match(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
assert sandbox._is_read_only_path("/home/user/skills") is True
def test_write_file_blocked_on_read_only(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
],
)
# Skills dir is read-only, write should be blocked
with pytest.raises(OSError) as exc_info:
sandbox.write_file("/mnt/skills/new_file.py", "content")
assert exc_info.value.errno == errno.EROFS
def test_write_file_allowed_on_writable_mount(self, tmp_path):
data_dir = tmp_path / "data"
data_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
],
)
sandbox.write_file("/mnt/data/file.txt", "content")
assert (data_dir / "file.txt").read_text() == "content"
def test_update_file_blocked_on_read_only(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
existing_file = skills_dir / "existing.py"
existing_file.write_bytes(b"original")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
],
)
with pytest.raises(OSError) as exc_info:
sandbox.update_file("/mnt/skills/existing.py", b"updated")
assert exc_info.value.errno == errno.EROFS
class TestMultipleMounts:
def test_multiple_read_write_mounts(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
data_dir = tmp_path / "data"
data_dir.mkdir()
external_dir = tmp_path / "external"
external_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
PathMapping(container_path="/mnt/external", local_path=str(external_dir), read_only=True),
],
)
# Skills is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/skills/file.py", "content")
# Data is writable
sandbox.write_file("/mnt/data/file.txt", "data content")
assert (data_dir / "file.txt").read_text() == "data content"
# External is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/external/file.txt", "content")
def test_nested_mounts_writable_under_readonly(self, tmp_path):
"""A writable mount nested under a read-only mount should allow writes."""
ro_dir = tmp_path / "ro"
ro_dir.mkdir()
rw_dir = ro_dir / "writable"
rw_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/repo", local_path=str(ro_dir), read_only=True),
PathMapping(container_path="/mnt/repo/writable", local_path=str(rw_dir), read_only=False),
],
)
# Parent mount is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/repo/file.txt", "content")
# Nested writable mount should allow writes
sandbox.write_file("/mnt/repo/writable/file.txt", "content")
assert (rw_dir / "file.txt").read_text() == "content"
def test_execute_command_path_replacement(self, tmp_path, monkeypatch):
data_dir = tmp_path / "data"
data_dir.mkdir()
test_file = data_dir / "test.txt"
test_file.write_text("hello")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(data_dir)),
],
)
# Mock subprocess to capture the resolved command
captured = {}
original_run = __import__("subprocess").run
def mock_run(*args, **kwargs):
if len(args) > 0:
captured["command"] = args[0]
return original_run(*args, **kwargs)
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.subprocess.run", mock_run)
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.LocalSandbox._get_shell", lambda self: "/bin/sh")
sandbox.execute_command("cat /mnt/data/test.txt")
# Verify the command received the resolved local path
assert str(data_dir) in captured.get("command", "")
def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path):
foo_dir = tmp_path / "foo"
foo_dir.mkdir()
foobar_dir = tmp_path / "foobar"
foobar_dir.mkdir()
target = foobar_dir / "file.txt"
target.write_text("test")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/foo", local_path=str(foo_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(target))
assert resolved == str(target.resolve())
def test_reverse_resolve_paths_in_output_supports_backslash_separator(self, tmp_path):
mount_dir = tmp_path / "mount"
mount_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(mount_dir)),
],
)
output = f"Copied: {mount_dir}\\file.txt"
masked = sandbox._reverse_resolve_paths_in_output(output)
assert "/mnt/data/file.txt" in masked
assert str(mount_dir) not in masked
class TestLocalSandboxProviderMounts:
def test_setup_path_mappings_uses_configured_skills_container_path_as_reserved_prefix(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="/custom-skills/nested", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/custom-skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
def test_setup_path_mappings_skips_relative_host_path(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path="relative/path", container_path="/mnt/data", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_skips_non_absolute_container_path(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="mnt/data", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_normalizes_container_path_trailing_slash(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="/mnt/data/", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]

View File

@ -1,5 +1,6 @@
"""Tests for LoopDetectionMiddleware."""
import copy
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@ -19,8 +20,13 @@ def _make_runtime(thread_id="test-thread"):
def _make_state(tool_calls=None, content=""):
"""Build a minimal AgentState dict with an AIMessage."""
msg = AIMessage(content=content, tool_calls=tool_calls or [])
"""Build a minimal AgentState dict with an AIMessage.
Deep-copies *content* when it is mutable (e.g. list) so that
successive calls never share the same object reference.
"""
safe_content = copy.deepcopy(content) if isinstance(content, list) else content
msg = AIMessage(content=safe_content, tool_calls=tool_calls or [])
return {"messages": [msg]}
@ -229,3 +235,114 @@ class TestLoopDetection:
mw._apply(_make_state(tool_calls=call), runtime)
assert "default" in mw._history
class TestAppendText:
"""Unit tests for LoopDetectionMiddleware._append_text."""
def test_none_content_returns_text(self):
result = LoopDetectionMiddleware._append_text(None, "hello")
assert result == "hello"
def test_str_content_concatenates(self):
result = LoopDetectionMiddleware._append_text("existing", "appended")
assert result == "existing\n\nappended"
def test_empty_str_content_concatenates(self):
result = LoopDetectionMiddleware._append_text("", "appended")
assert result == "\n\nappended"
def test_list_content_appends_text_block(self):
"""List content (e.g. Anthropic thinking mode) should get a new text block."""
content = [
{"type": "thinking", "text": "Let me think..."},
{"type": "text", "text": "Here is my answer"},
]
result = LoopDetectionMiddleware._append_text(content, "stop msg")
assert isinstance(result, list)
assert len(result) == 3
assert result[0] == content[0]
assert result[1] == content[1]
assert result[2] == {"type": "text", "text": "\n\nstop msg"}
def test_empty_list_content_appends_text_block(self):
result = LoopDetectionMiddleware._append_text([], "stop msg")
assert isinstance(result, list)
assert len(result) == 1
assert result[0] == {"type": "text", "text": "\n\nstop msg"}
def test_unexpected_type_coerced_to_str(self):
"""Unexpected content types should be coerced to str as a fallback."""
result = LoopDetectionMiddleware._append_text(42, "stop msg")
assert isinstance(result, str)
assert result == "42\n\nstop msg"
def test_list_content_not_mutated_in_place(self):
"""_append_text must not modify the original list."""
original = [{"type": "text", "text": "hello"}]
result = LoopDetectionMiddleware._append_text(original, "appended")
assert len(original) == 1 # original unchanged
assert len(result) == 2 # new list has the appended block
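# Contract pinned by the tests above (sketch; a static method on the middleware):
#
#     @staticmethod
#     def _append_text(content, text):
#         if content is None:
#             return text
#         if isinstance(content, str):
#             return f"{content}\n\n{text}"
#         if isinstance(content, list):
#             # New list — the caller's content is never mutated in place.
#             return [*content, {"type": "text", "text": f"\n\n{text}"}]
#         return f"{content}\n\n{text}"  # unexpected types coerced via str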
class TestHardStopWithListContent:
"""Regression tests: hard stop must not crash when AIMessage.content is a list."""
def test_hard_stop_with_list_content(self):
"""Hard stop on list content should not raise TypeError (regression)."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
# Build state with list content (e.g. Anthropic thinking mode)
list_content = [
{"type": "thinking", "text": "Let me think..."},
{"type": "text", "text": "I'll run ls"},
]
for _ in range(3):
mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
# Fourth call triggers hard stop — must not raise TypeError
result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg, AIMessage)
assert msg.tool_calls == []
# Content should remain a list with the stop message appended
assert isinstance(msg.content, list)
assert len(msg.content) == 3
assert msg.content[2]["type"] == "text"
assert _HARD_STOP_MSG in msg.content[2]["text"]
def test_hard_stop_with_none_content(self):
"""Hard stop on None content should produce a plain string."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
# Fourth call with default empty-string content
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg.content, str)
assert _HARD_STOP_MSG in msg.content
def test_hard_stop_with_str_content(self):
"""Hard stop on str content should concatenate the stop message."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg.content, str)
assert msg.content.startswith("thinking...")
assert _HARD_STOP_MSG in msg.content

View File

@ -154,3 +154,22 @@ def test_format_memory_renders_correction_without_source_error_normally() -> Non
assert "Use make dev for local development." in result
assert "avoid:" not in result
def test_format_memory_includes_long_term_background() -> None:
"""longTermBackground in history must be injected into the prompt."""
memory_data = {
"user": {},
"history": {
"recentMonths": {"summary": "Recent activity summary"},
"earlierContext": {"summary": "Earlier context summary"},
"longTermBackground": {"summary": "Core expertise in distributed systems"},
},
"facts": [],
}
result = format_memory_for_injection(memory_data, max_tokens=2000)
assert "Background: Core expertise in distributed systems" in result
assert "Recent: Recent activity summary" in result
assert "Earlier: Earlier context summary" in result

View File

@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=True,
reinforcement_detected=False,
)
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
assert len(queue._queue) == 1
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].reinforcement_detected is True
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue._queue = [
ConversationContext(
thread_id="thread-1",
messages=["conversation"],
agent_name="lead_agent",
reinforcement_detected=True,
)
]
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
queue._process_queue()
mock_updater.update_memory.assert_called_once_with(
messages=["conversation"],
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=False,
reinforcement_detected=True,
)

View File

@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" not in prompt
class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
def test_duplicate_fact_different_case_not_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
# Same fact with different casing should be treated as duplicate
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
# Should still have only 1 fact (duplicate rejected)
assert len(result["facts"]) == 1
assert result["facts"][0]["content"] == "User prefers Python"
def test_unique_fact_different_case_and_content_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert len(result["facts"]) == 2
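# The dedup rule exercised above, sketched (case-insensitive comparison is what
# the tests pin; the .strip() is an assumption):
#
#     existing = {f["content"].strip().lower() for f in current_memory["facts"]}
#     accepted = [f for f in update_data["newFacts"]
#                 if f["content"].strip().lower() not in existing]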
class TestReinforcementHint:
"""Tests that reinforcement_detected injects the correct hint into the prompt."""
@staticmethod
def _make_mock_model(json_response: str):
model = MagicMock()
response = MagicMock()
response.content = f"```json\n{json_response}\n```"
model.invoke.return_value = response
return model
def test_reinforcement_hint_injected_when_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Yes, exactly! That's what I needed."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Great to hear!"
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Tell me more."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Sure."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "No wait, that's wrong. Actually yes, exactly right."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Got it."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt

View File

@ -10,7 +10,7 @@ persisting in long-term memory:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
# ---------------------------------------------------------------------------
# Helpers
@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory:
mem = {"user": {}, "history": {}, "facts": []}
result = _strip_upload_mentions_from_memory(mem)
assert result == {"user": {}, "history": {}, "facts": []}
# ===========================================================================
# detect_reinforcement
# ===========================================================================
class TestDetectReinforcement:
def test_detects_english_reinforcement_signal(self):
msgs = [
_human("Can you summarise it in bullet points?"),
_ai("Here are the key points: ..."),
_human("Yes, exactly! That's what I needed."),
_ai("Glad it helped."),
]
assert detect_reinforcement(msgs) is True
def test_detects_perfect_signal(self):
msgs = [
_human("Write it more concisely."),
_ai("Here is the concise version."),
_human("Perfect."),
_ai("Great!"),
]
assert detect_reinforcement(msgs) is True
def test_detects_chinese_reinforcement_signal(self):
msgs = [
_human("帮我用要点来总结"),
_ai("好的,要点如下:..."),
_human("完全正确,就是这个意思"),
_ai("很高兴能帮到你"),
]
assert detect_reinforcement(msgs) is True
def test_returns_false_without_signal(self):
msgs = [
_human("What does this function do?"),
_ai("It processes the input data."),
_human("Can you show me an example?"),
]
assert detect_reinforcement(msgs) is False
def test_only_checks_recent_messages(self):
# Reinforcement signal buried beyond the -6 window should not trigger
msgs = [
_human("Yes, exactly right."),
_ai("Noted."),
_human("Let's discuss tests."),
_ai("Sure."),
_human("What about linting?"),
_ai("Use ruff."),
_human("And formatting?"),
_ai("Use make format."),
]
assert detect_reinforcement(msgs) is False
def test_does_not_conflict_with_correction(self):
# A message can trigger correction but not reinforcement
msgs = [
_human("That's wrong, try again."),
_ai("Corrected."),
]
assert detect_reinforcement(msgs) is False
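# Sketch of the window logic these tests pin down — the signal list is an
# illustrative subset, not the real one:
#
#     def detect_reinforcement(messages) -> bool:
#         recent = messages[-6:]  # the "-6 window" the burial test exercises
#         signals = ("yes, exactly", "perfect", "完全正确")
#         return any(
#             m.type == "human" and any(s in str(m.content).lower() for s in signals)
#             for m in recent
#         )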

View File

@ -0,0 +1,393 @@
from types import SimpleNamespace
from unittest.mock import patch
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
from deerflow.sandbox.tools import glob_tool, grep_tool
def _make_runtime(tmp_path):
workspace = tmp_path / "workspace"
uploads = tmp_path / "uploads"
outputs = tmp_path / "outputs"
workspace.mkdir()
uploads.mkdir()
outputs.mkdir()
return SimpleNamespace(
state={
"sandbox": {"sandbox_id": "local"},
"thread_data": {
"workspace_path": str(workspace),
"uploads_path": str(uploads),
"outputs_path": str(outputs),
},
},
context={"thread_id": "thread-1"},
)
def test_glob_tool_returns_virtual_paths_and_ignores_common_dirs(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "app.py").write_text("print('hi')\n", encoding="utf-8")
(workspace / "pkg").mkdir()
(workspace / "pkg" / "util.py").write_text("print('util')\n", encoding="utf-8")
(workspace / "node_modules").mkdir()
(workspace / "node_modules" / "skip.py").write_text("ignored\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = glob_tool.func(
runtime=runtime,
description="find python files",
pattern="**/*.py",
path="/mnt/user-data/workspace",
)
assert "/mnt/user-data/workspace/app.py" in result
assert "/mnt/user-data/workspace/pkg/util.py" in result
assert "node_modules" not in result
assert str(workspace) not in result
def test_glob_tool_supports_skills_virtual_paths(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
skills_dir = tmp_path / "skills"
(skills_dir / "public" / "demo").mkdir(parents=True)
(skills_dir / "public" / "demo" / "SKILL.md").write_text("# Demo\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=str(skills_dir)),
):
result = glob_tool.func(
runtime=runtime,
description="find skills",
pattern="**/SKILL.md",
path="/mnt/skills",
)
assert "/mnt/skills/public/demo/SKILL.md" in result
assert str(skills_dir) not in result
def test_grep_tool_filters_by_glob_and_skips_binary_files(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "main.py").write_text("TODO = 'ship it'\nprint(TODO)\n", encoding="utf-8")
(workspace / "notes.txt").write_text("TODO in txt should be filtered\n", encoding="utf-8")
(workspace / "image.bin").write_bytes(b"\0binary TODO")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="find todo references",
pattern="TODO",
path="/mnt/user-data/workspace",
glob="**/*.py",
)
assert "/mnt/user-data/workspace/main.py:1: TODO = 'ship it'" in result
assert "notes.txt" not in result
assert "image.bin" not in result
assert str(workspace) not in result
def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
result = grep_tool.func(
runtime=runtime,
description="limit matches",
pattern="TODO",
path="/mnt/user-data/workspace",
max_results=2,
)
assert "Found 2 matches under /mnt/user-data/workspace (showing first 2)" in result
assert "TODO one" in result
assert "TODO two" in result
assert "TODO three" not in result
assert "Results truncated." in result
def test_glob_tool_include_dirs_filters_nested_ignored_paths(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "src").mkdir()
(workspace / "src" / "main.py").write_text("x\n", encoding="utf-8")
(workspace / "node_modules").mkdir()
(workspace / "node_modules" / "lib").mkdir()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = glob_tool.func(
runtime=runtime,
description="find dirs",
pattern="**",
path="/mnt/user-data/workspace",
include_dirs=True,
)
assert "src" in result
assert "node_modules" not in result
def test_grep_tool_literal_mode(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "file.py").write_text("price = (a+b)\nresult = a+b\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# literal=True should treat (a+b) as a plain string, not a regex group
result = grep_tool.func(
runtime=runtime,
description="literal search",
pattern="(a+b)",
path="/mnt/user-data/workspace",
literal=True,
)
assert "price = (a+b)" in result
assert "result = a+b" not in result
def test_grep_tool_case_sensitive(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "file.py").write_text("TODO: fix\ntodo: also fix\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="case sensitive search",
pattern="TODO",
path="/mnt/user-data/workspace",
case_sensitive=True,
)
assert "TODO: fix" in result
assert "todo: also fix" not in result
def test_grep_tool_invalid_regex_returns_error(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="bad pattern",
pattern="[invalid",
path="/mnt/user-data/workspace",
)
assert "Invalid regex pattern" in result
def test_aio_sandbox_glob_include_dirs_filters_nested_ignored(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(name="src", path="/mnt/workspace/src"),
SimpleNamespace(name="node_modules", path="/mnt/workspace/node_modules"),
# child of node_modules — should be filtered via should_ignore_path
SimpleNamespace(name="lib", path="/mnt/workspace/node_modules/lib"),
]
)
),
)
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
assert "/mnt/workspace/src" in matches
assert "/mnt/workspace/node_modules" not in matches
assert "/mnt/workspace/node_modules/lib" not in matches
assert truncated is False
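# Sketch of the nested filtering the test above relies on, assuming
# should_ignore_path rejects any path containing an ignored directory
# segment, so children of node_modules are dropped even when the listing
# returns them individually. The ignore set is illustrative.
_IGNORED_DIRS = {"node_modules", ".git", "__pycache__", ".venv"}

def _should_ignore_path_sketch(path: str) -> bool:
    return any(part in _IGNORED_DIRS for part in path.strip("/").split("/"))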
def test_aio_sandbox_grep_invalid_regex_raises() -> None:
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
        with pytest.raises(re.error):
            sandbox.grep("/mnt/workspace", "[invalid")
def test_aio_sandbox_glob_parses_json(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"find_files",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(files=["/mnt/user-data/workspace/app.py", "/mnt/user-data/workspace/node_modules/skip.py"])),
)
matches, truncated = sandbox.glob("/mnt/user-data/workspace", "**/*.py")
assert matches == ["/mnt/user-data/workspace/app.py"]
assert truncated is False
def test_aio_sandbox_grep_parses_json(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(
name="app.py",
path="/mnt/user-data/workspace/app.py",
is_directory=False,
)
]
)
),
)
monkeypatch.setattr(
sandbox._client.file,
"search_in_file",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True"])),
)
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
assert truncated is False
def test_find_glob_matches_raises_not_a_directory(tmp_path) -> None:
    file_path = tmp_path / "file.txt"
    file_path.write_text("x\n", encoding="utf-8")
    with pytest.raises(NotADirectoryError):
        find_glob_matches(file_path, "**/*.py")
def test_find_grep_matches_raises_not_a_directory(tmp_path) -> None:
    file_path = tmp_path / "file.txt"
    file_path.write_text("TODO\n", encoding="utf-8")
    with pytest.raises(NotADirectoryError):
        find_grep_matches(file_path, "TODO")
def test_find_grep_matches_skips_symlink_outside_root(tmp_path) -> None:
workspace = tmp_path / "workspace"
workspace.mkdir()
outside = tmp_path / "outside.txt"
outside.write_text("TODO outside\n", encoding="utf-8")
(workspace / "outside-link.txt").symlink_to(outside)
matches, truncated = find_grep_matches(workspace, "TODO")
assert matches == []
assert truncated is False
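# Sketch of the boundary check the symlink test assumes: resolve the
# candidate (following symlinks) and require it to remain under the
# resolved search root. The helper name is hypothetical.
from pathlib import Path

def _is_within_root(candidate: Path, root: Path) -> bool:
    try:
        candidate.resolve().relative_to(root.resolve())
    except ValueError:  # resolved target escapes the root
        return False
    return True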
def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "a.py").write_text("print('a')\n", encoding="utf-8")
(workspace / "b.py").write_text("print('b')\n", encoding="utf-8")
(workspace / "c.py").write_text("print('c')\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
monkeypatch.setattr(
"deerflow.sandbox.tools.get_app_config",
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
)
result = glob_tool.func(
runtime=runtime,
description="limit glob matches",
pattern="**/*.py",
path="/mnt/user-data/workspace",
max_results=2,
)
assert "Found 2 paths under /mnt/user-data/workspace (showing first 2)" in result
assert "Results truncated." in result
def test_aio_sandbox_glob_include_dirs_enforces_root_boundary(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(name="src", path="/mnt/workspace/src"),
SimpleNamespace(name="src2", path="/mnt/workspace2/src2"),
]
)
),
)
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
assert matches == ["/mnt/workspace/src"]
assert truncated is False
def test_aio_sandbox_grep_skips_mismatched_line_number_payloads(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(
name="app.py",
path="/mnt/user-data/workspace/app.py",
is_directory=False,
)
]
)
),
)
monkeypatch.setattr(
sandbox._client.file,
"search_in_file",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True", "extra"])),
)
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
assert truncated is False

View File

@ -8,7 +8,10 @@ import pytest
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
_get_custom_mount_for_path,
_get_custom_mounts,
_is_acp_workspace_path,
_is_custom_mount_path,
_is_skills_path,
_reject_path_traversal,
_resolve_acp_workspace_path,
@ -39,6 +42,53 @@ def test_replace_virtual_path_maps_virtual_root_and_subpaths() -> None:
assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data"
def test_replace_virtual_path_preserves_trailing_slash() -> None:
"""Trailing slash must survive virtual-to-actual path replacement.
Regression: '/mnt/user-data/workspace/' was previously returned without
the trailing slash, causing string concatenations like
output_dir + 'file.txt' to produce a missing-separator path.
"""
result = replace_virtual_path("/mnt/user-data/workspace/", _THREAD_DATA)
assert result.endswith("/"), f"Expected trailing slash, got: {result!r}"
assert result == "/tmp/deer-flow/threads/t1/user-data/workspace/"
def test_replace_virtual_path_preserves_trailing_slash_windows_style() -> None:
"""Trailing slash must be preserved as backslash when actual_base is Windows-style.
If actual_base uses backslash separators, appending '/' would produce a
mixed-separator path. The separator must match the style of actual_base.
"""
win_thread_data = {
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
}
result = replace_virtual_path("/mnt/user-data/workspace/", win_thread_data)
assert result.endswith("\\"), f"Expected trailing backslash for Windows path, got: {result!r}"
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing_slash() -> None:
"""Nested Windows-style subdirectories must keep backslashes throughout."""
win_thread_data = {
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
}
result = replace_virtual_path("/mnt/user-data/workspace/subdir/", win_thread_data)
assert result == "C:\\deer-flow\\threads\\t1\\user-data\\workspace\\subdir\\"
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
"""Trailing slash on a virtual path inside a command must be preserved."""
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
# ---------- mask_local_paths_in_output ----------
@ -96,6 +146,25 @@ def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
with pytest.raises(PermissionError, match="configured mount paths"):
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
with pytest.raises(PermissionError, match="Only paths under"):
@ -235,6 +304,22 @@ def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
"""HTTP URLs must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
validate_local_bash_command_paths(
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
@ -567,6 +652,156 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
# ---------- Custom mount path tests ----------
def _mock_custom_mounts():
"""Create mock VolumeMountConfig objects for testing."""
from deerflow.config.sandbox_config import VolumeMountConfig
return [
VolumeMountConfig(host_path="/home/user/code-read", container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path="/home/user/data", container_path="/mnt/data", read_only=False),
]
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
assert _is_custom_mount_path("/mnt/code-read") is True
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
assert _is_custom_mount_path("/mnt/data") is True
assert _is_custom_mount_path("/mnt/data/file.txt") is True
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
assert _is_custom_mount_path("/mnt/other") is False
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path="/var/mnt", container_path="/mnt", read_only=False),
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
mount = _get_custom_mount_for_path("/mnt/code/file.py")
assert mount is not None
assert mount.container_path == "/mnt/code"
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
"""read_file / ls should be able to access custom mount paths."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
"""write_file / str_replace must NOT write to read-only custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
"""write_file / str_replace should succeed on writable custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
"""Path traversal via .. in custom mount paths must be rejected."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
"""bash commands referencing custom mount paths should be allowed."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
"""Bash commands with traversal in custom mount paths should be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
"""Paths not matching any custom mount should still be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should cache after first successful load."""
# Clear any existing cache
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
# Use real directories so host_path.exists() filtering passes
dir_a = tmp_path / "code-read"
dir_a.mkdir()
dir_b = tmp_path / "data"
dir_b.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
mounts = [
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 2
# After caching, should return cached value even without mock
assert hasattr(_get_custom_mounts, "_cached")
assert len(_get_custom_mounts()) == 2
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should only return mounts whose host_path exists."""
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
existing_dir = tmp_path / "existing"
existing_dir.mkdir()
mounts = [
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
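# Sketch of the function-attribute caching pattern the two tests above
# exercise, assuming the real helper filters out missing host paths and
# caches the result on itself after the first successful load.
def _get_custom_mounts_sketch():
    if hasattr(_get_custom_mounts_sketch, "_cached"):
        return _get_custom_mounts_sketch._cached
    from pathlib import Path
    from deerflow.config import get_app_config  # patch target used above
    mounts = [m for m in (get_app_config().sandbox.mounts or []) if Path(m.host_path).exists()]
    _get_custom_mounts_sketch._cached = mounts
    return mounts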
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
assert mount is None
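# Sketch of the longest-prefix resolution the boundary test pins down:
# a mount matches only on an exact path or a '/'-delimited prefix (never a
# raw string prefix like /mnt/code-read-extra), and the longest matching
# container_path wins. A hypothetical standalone mirror of the real helper.
def _get_custom_mount_for_path_sketch(path: str, mounts):
    best = None
    for mount in mounts:
        base = mount.container_path.rstrip("/")
        if path == base or path.startswith(base + "/"):
            if best is None or len(base) > len(best.container_path.rstrip("/")):
                best = mount
    return best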
def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) -> None:
class SharedSandbox:
def __init__(self) -> None:

View File

@ -140,6 +140,193 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
# ---------------------------------------------------------------------------
# END sentinel guarantee tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_end_sentinel_delivered_when_queue_full():
"""END sentinel must always be delivered, even when the queue is completely full.
This is the critical regression test for the bug where publish_end()
would silently drop the END sentinel when the queue was full, causing
subscribe() to hang forever and leak resources.
"""
bridge = MemoryStreamBridge(queue_maxsize=2)
run_id = "run-end-full"
# Fill the queue to capacity
await bridge.publish(run_id, "event-1", {"n": 1})
await bridge.publish(run_id, "event-2", {"n": 2})
assert bridge._queues[run_id].full()
# publish_end should succeed by evicting old events
await bridge.publish_end(run_id)
# Subscriber must receive END_SENTINEL
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
@pytest.mark.anyio
async def test_end_sentinel_evicts_oldest_events():
"""When queue is full, publish_end evicts the oldest events to make room."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-evict"
# Fill queue with one event
await bridge.publish(run_id, "will-be-evicted", {})
assert bridge._queues[run_id].full()
# publish_end must succeed
await bridge.publish_end(run_id)
# The only event we should get is END_SENTINEL (the regular event was evicted)
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
assert len(events) == 1
assert events[0] is END_SENTINEL
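# Sketch of the eviction behavior these tests lock in, assuming publish_end
# drains the oldest queued events until the END sentinel fits, instead of
# silently dropping it when the queue is full.
import asyncio

async def _publish_end_sketch(queue: asyncio.Queue, sentinel) -> None:
    while queue.full():
        try:
            queue.get_nowait()  # evict the oldest event to make room
        except asyncio.QueueEmpty:  # raced with a consumer; space freed up
            break
    queue.put_nowait(sentinel)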
@pytest.mark.anyio
async def test_end_sentinel_no_eviction_when_space_available():
"""When queue has space, publish_end should not evict anything."""
bridge = MemoryStreamBridge(queue_maxsize=10)
run_id = "run-no-evict"
await bridge.publish(run_id, "event-1", {"n": 1})
await bridge.publish(run_id, "event-2", {"n": 2})
await bridge.publish_end(run_id)
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
# All events plus END should be present
assert len(events) == 3
assert events[0].event == "event-1"
assert events[1].event == "event-2"
assert events[2] is END_SENTINEL
@pytest.mark.anyio
async def test_concurrent_tasks_end_sentinel():
"""Multiple concurrent producer/consumer pairs should all terminate properly.
Simulates the production scenario where multiple runs share a single
bridge instance; each must receive its own END sentinel.
"""
bridge = MemoryStreamBridge(queue_maxsize=4)
num_runs = 4
async def producer(run_id: str):
for i in range(10): # More events than queue capacity
await bridge.publish(run_id, f"event-{i}", {"i": i})
await bridge.publish_end(run_id)
async def consumer(run_id: str) -> list:
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
return events
return events # pragma: no cover
# Run producers and consumers concurrently
run_ids = [f"concurrent-{i}" for i in range(num_runs)]
producers = [producer(rid) for rid in run_ids]
consumers = [consumer(rid) for rid in run_ids]
# Start consumers first, then producers
consumer_tasks = [asyncio.create_task(c) for c in consumers]
await asyncio.gather(*producers)
results = await asyncio.wait_for(
asyncio.gather(*consumer_tasks),
timeout=10.0,
)
for i, events in enumerate(results):
assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
# ---------------------------------------------------------------------------
# Drop counter tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_dropped_count_tracking():
"""Dropped events should be tracked per run_id."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-drop-count"
# Fill the queue
await bridge.publish(run_id, "first", {})
# The queue is full, but no publish() call is dropped here; we verify
# that publish_end's eviction leaves the drop counter untouched.
await bridge.publish_end(run_id)
# dropped_count tracks publish() drops, not publish_end evictions
assert bridge.dropped_count(run_id) == 0
# cleanup should also clear the counter
await bridge.cleanup(run_id)
assert bridge.dropped_count(run_id) == 0
@pytest.mark.anyio
async def test_dropped_total():
"""dropped_total should sum across all runs."""
bridge = MemoryStreamBridge(queue_maxsize=256)
# No drops yet
assert bridge.dropped_total == 0
# Manually set some counts to verify the property
bridge._dropped_counts["run-a"] = 3
bridge._dropped_counts["run-b"] = 7
assert bridge.dropped_total == 10
@pytest.mark.anyio
async def test_cleanup_clears_dropped_counts():
"""cleanup() should clear the dropped counter for the run."""
bridge = MemoryStreamBridge(queue_maxsize=256)
run_id = "run-cleanup-drops"
bridge._get_or_create_queue(run_id)
bridge._dropped_counts[run_id] = 5
await bridge.cleanup(run_id)
assert run_id not in bridge._dropped_counts
@pytest.mark.anyio
async def test_close_clears_dropped_counts():
"""close() should clear all dropped counters."""
bridge = MemoryStreamBridge(queue_maxsize=256)
bridge._dropped_counts["run-x"] = 10
bridge._dropped_counts["run-y"] = 20
await bridge.close()
assert bridge.dropped_total == 0
assert len(bridge._dropped_counts) == 0
# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------

View File

@ -1,8 +1,8 @@
"""Tests for subagent timeout configuration.
"""Tests for subagent runtime configuration.
Covers:
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
- get_timeout_for() resolution logic (global vs per-agent)
- get_timeout_for() / get_max_turns_for() resolution logic
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
- registry.get_subagent_config() applies config overrides
- registry.list_subagents() applies overrides for all agents
@ -24,9 +24,20 @@ from deerflow.subagents.config import SubagentConfig
# ---------------------------------------------------------------------------
def _reset_subagents_config(timeout_seconds: int = 900, agents: dict | None = None) -> None:
def _reset_subagents_config(
timeout_seconds: int = 900,
*,
max_turns: int | None = None,
agents: dict | None = None,
) -> None:
"""Reset global subagents config to a known state."""
load_subagents_config_from_dict({"timeout_seconds": timeout_seconds, "agents": agents or {}})
load_subagents_config_from_dict(
{
"timeout_seconds": timeout_seconds,
"max_turns": max_turns,
"agents": agents or {},
}
)
# ---------------------------------------------------------------------------
@ -38,22 +49,29 @@ class TestSubagentOverrideConfig:
def test_default_is_none(self):
override = SubagentOverrideConfig()
assert override.timeout_seconds is None
assert override.max_turns is None
def test_explicit_value(self):
override = SubagentOverrideConfig(timeout_seconds=300)
override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42)
assert override.timeout_seconds == 300
assert override.max_turns == 42
def test_rejects_zero(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=0)
def test_rejects_negative(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=-1)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=-1)
def test_minimum_valid_value(self):
override = SubagentOverrideConfig(timeout_seconds=1)
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
assert override.timeout_seconds == 1
assert override.max_turns == 1
# ---------------------------------------------------------------------------
@ -66,66 +84,86 @@ class TestSubagentsAppConfigDefaults:
config = SubagentsAppConfig()
assert config.timeout_seconds == 900
def test_default_max_turns_override_is_none(self):
config = SubagentsAppConfig()
assert config.max_turns is None
def test_default_agents_empty(self):
config = SubagentsAppConfig()
assert config.agents == {}
def test_custom_global_timeout(self):
config = SubagentsAppConfig(timeout_seconds=1800)
def test_custom_global_runtime_overrides(self):
config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
assert config.timeout_seconds == 1800
assert config.max_turns == 120
def test_rejects_zero_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=0)
def test_rejects_negative_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=-60)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=-60)
# ---------------------------------------------------------------------------
# SubagentsAppConfig.get_timeout_for()
# SubagentsAppConfig resolution helpers
# ---------------------------------------------------------------------------
class TestGetTimeoutFor:
class TestRuntimeResolution:
def test_returns_global_default_when_no_override(self):
config = SubagentsAppConfig(timeout_seconds=600)
assert config.get_timeout_for("general-purpose") == 600
assert config.get_timeout_for("bash") == 600
assert config.get_timeout_for("unknown-agent") == 600
assert config.get_max_turns_for("general-purpose", 100) == 100
assert config.get_max_turns_for("bash", 60) == 60
def test_returns_per_agent_override_when_set(self):
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
max_turns=120,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
)
assert config.get_timeout_for("bash") == 300
assert config.get_max_turns_for("bash", 60) == 80
def test_other_agents_still_use_global_default(self):
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
max_turns=140,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 140
def test_agent_with_none_override_falls_back_to_global(self):
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None)},
max_turns=150,
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 150
def test_multiple_per_agent_overrides(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=120,
agents={
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800),
"bash": SubagentOverrideConfig(timeout_seconds=120),
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
"bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
},
)
assert config.get_timeout_for("general-purpose") == 1800
assert config.get_timeout_for("bash") == 120
assert config.get_max_turns_for("general-purpose", 100) == 200
assert config.get_max_turns_for("bash", 60) == 80
# ---------------------------------------------------------------------------
@ -139,54 +177,63 @@ class TestLoadSubagentsConfig:
_reset_subagents_config()
def test_load_global_timeout(self):
load_subagents_config_from_dict({"timeout_seconds": 300})
load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
assert get_subagents_app_config().timeout_seconds == 300
assert get_subagents_app_config().max_turns == 120
def test_load_with_per_agent_overrides(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800},
"bash": {"timeout_seconds": 60},
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 1800
assert cfg.get_timeout_for("bash") == 60
assert cfg.get_max_turns_for("general-purpose", 100) == 200
assert cfg.get_max_turns_for("bash", 60) == 80
def test_load_partial_override(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 600,
"agents": {"bash": {"timeout_seconds": 120}},
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
}
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 600
assert cfg.get_timeout_for("bash") == 120
assert cfg.get_max_turns_for("general-purpose", 100) == 100
assert cfg.get_max_turns_for("bash", 60) == 70
def test_load_empty_dict_uses_defaults(self):
load_subagents_config_from_dict({})
cfg = get_subagents_app_config()
assert cfg.timeout_seconds == 900
assert cfg.max_turns is None
assert cfg.agents == {}
def test_load_replaces_previous_config(self):
load_subagents_config_from_dict({"timeout_seconds": 100})
load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
assert get_subagents_app_config().timeout_seconds == 100
assert get_subagents_app_config().max_turns == 90
load_subagents_config_from_dict({"timeout_seconds": 200})
load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
assert get_subagents_app_config().timeout_seconds == 200
assert get_subagents_app_config().max_turns == 110
def test_singleton_returns_same_instance_between_calls(self):
load_subagents_config_from_dict({"timeout_seconds": 777})
load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
assert get_subagents_app_config() is get_subagents_app_config()
# ---------------------------------------------------------------------------
# registry.get_subagent_config timeout override applied
# registry.get_subagent_config runtime overrides applied
# ---------------------------------------------------------------------------
@ -211,25 +258,29 @@ class TestRegistryGetSubagentConfig:
_reset_subagents_config(timeout_seconds=900)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 900
assert config.max_turns == 100
def test_global_timeout_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=1800)
_reset_subagents_config(timeout_seconds=1800, max_turns=140)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 1800
assert config.max_turns == 140
def test_per_agent_timeout_override_applied(self):
def test_per_agent_runtime_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"timeout_seconds": 120}},
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
bash_config = get_subagent_config("bash")
assert bash_config.timeout_seconds == 120
assert bash_config.max_turns == 80
def test_per_agent_override_does_not_affect_other_agents(self):
from deerflow.subagents.registry import get_subagent_config
@ -237,11 +288,13 @@ class TestRegistryGetSubagentConfig:
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"timeout_seconds": 120}},
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
gp_config = get_subagent_config("general-purpose")
assert gp_config.timeout_seconds == 900
assert gp_config.max_turns == 120
def test_builtin_config_object_is_not_mutated(self):
"""Registry must return a new object, leaving the builtin default intact."""
@ -249,24 +302,27 @@ class TestRegistryGetSubagentConfig:
from deerflow.subagents.registry import get_subagent_config
original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
load_subagents_config_from_dict({"timeout_seconds": 42})
original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})
returned = get_subagent_config("bash")
assert returned.timeout_seconds == 42
assert returned.max_turns == 88
assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns
def test_config_preserves_other_fields(self):
"""Applying timeout override must not change other SubagentConfig fields."""
"""Applying runtime overrides must not change other SubagentConfig fields."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=300)
_reset_subagents_config(timeout_seconds=300, max_turns=140)
original = BUILTIN_SUBAGENTS["general-purpose"]
overridden = get_subagent_config("general-purpose")
assert overridden.name == original.name
assert overridden.description == original.description
assert overridden.max_turns == original.max_turns
assert overridden.max_turns == 140
assert overridden.model == original.model
assert overridden.tools == original.tools
assert overridden.disallowed_tools == original.disallowed_tools
@ -291,9 +347,10 @@ class TestRegistryListSubagents:
def test_all_returned_configs_get_global_override(self):
from deerflow.subagents.registry import list_subagents
_reset_subagents_config(timeout_seconds=123)
_reset_subagents_config(timeout_seconds=123, max_turns=77)
for cfg in list_subagents():
assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"
def test_per_agent_overrides_reflected_in_list(self):
from deerflow.subagents.registry import list_subagents
@ -301,15 +358,18 @@ class TestRegistryListSubagents:
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800},
"bash": {"timeout_seconds": 60},
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
)
by_name = {cfg.name: cfg for cfg in list_subagents()}
assert by_name["general-purpose"].timeout_seconds == 1800
assert by_name["bash"].timeout_seconds == 60
assert by_name["general-purpose"].max_turns == 200
assert by_name["bash"].max_turns == 80
# ---------------------------------------------------------------------------

View File

@ -1,5 +1,5 @@
import asyncio
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
from app.gateway.routers import suggestions
@ -43,7 +43,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```'))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@ -61,7 +61,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@ -79,7 +79,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@ -94,7 +94,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.side_effect = RuntimeError("boom")
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom"))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
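# The mechanical pattern behind this diff: generate_suggestions now awaits
# model.ainvoke, and a plain MagicMock attribute is not awaitable, so each
# fake model needs an AsyncMock whose return_value resolves when awaited.
from unittest.mock import AsyncMock, MagicMock

fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='["Q1"]'))
# `await fake_model.ainvoke(...)` now returns the mocked response object.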

View File

@ -0,0 +1,111 @@
"""Tests for thread_runs router with auth decorators.
These tests verify that auth decorators properly enforce permission checks
on run endpoints. They follow the same pattern as test_threads_router.py.
"""
from unittest.mock import MagicMock, patch
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.auth.models import User
from app.gateway.authz import AuthContext
from app.gateway.routers.thread_runs import router
def test_create_run_requires_auth():
"""POST /{thread_id}/runs requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.post(
"/api/threads/test-thread/runs",
json={"assistant_id": "test"},
)
assert response.status_code == 401
def test_create_run_with_auth():
"""POST /{thread_id}/runs with valid auth passes through."""
app = FastAPI()
app.include_router(router)
mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
mock_auth = AuthContext(
user=mock_user,
permissions=["runs:create", "threads:read", "threads:write"],
)
# Mock the checkpointer and run_manager to avoid 503s
mock_checkpointer = MagicMock()
mock_run_manager = MagicMock()
mock_run_manager.list_by_thread = MagicMock(return_value=[])
mock_stream_bridge = MagicMock()
with patch("app.gateway.routers.thread_runs.get_checkpointer", return_value=mock_checkpointer):
with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
with patch("app.gateway.routers.thread_runs.get_stream_bridge", return_value=mock_stream_bridge):
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app, raise_server_exceptions=False) as client:
# Without a real checkpointer.setup, this will 500, but the point is that auth passed
response = client.post(
"/api/threads/test-thread/runs",
json={"assistant_id": "test"},
)
# Auth passed if we don't get 401
assert response.status_code != 401
def test_list_runs_requires_auth():
"""GET /{thread_id}/runs requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.get("/api/threads/test-thread/runs")
assert response.status_code == 401
def test_list_runs_with_auth():
"""GET /{thread_id}/runs with auth passes through."""
app = FastAPI()
app.include_router(router)
mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
mock_auth = AuthContext(
user=mock_user,
permissions=["runs:read", "threads:read"],
)
mock_run_manager = MagicMock()
mock_run_manager.list_by_thread = MagicMock(return_value=[])
with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app, raise_server_exceptions=False) as client:
response = client.get("/api/threads/test-thread/runs")
# Should not be 401 (may be 500 or other, but auth passed)
assert response.status_code != 401
def test_get_run_requires_auth():
"""GET /{thread_id}/runs/{run_id} requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.get("/api/threads/test-thread/runs/run-123")
assert response.status_code == 401
def test_cancel_run_requires_auth():
"""POST /{thread_id}/runs/{run_id}/cancel requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.post("/api/threads/test-thread/runs/run-123/cancel")
assert response.status_code == 401
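# Hedged sketch of the decorator pattern these tests exercise. Only
# _authenticate and AuthContext are taken from the patches and imports
# above; the decorator name, its signature, and the 401/403 details are
# assumptions, not the module's actual API.
import functools
from fastapi import HTTPException, Request
from app.gateway.authz import _authenticate  # patch target in the tests above

def require_permissions(*needed: str):
    def decorator(endpoint):
        @functools.wraps(endpoint)
        async def wrapper(request: Request, *args, **kwargs):
            ctx = _authenticate(request)  # signature assumed for illustration
            if ctx is None:
                raise HTTPException(status_code=401, detail="Not authenticated")
            if not set(needed).issubset(set(ctx.permissions)):
                raise HTTPException(status_code=403, detail="Insufficient permissions")
            return await endpoint(request, *args, **kwargs)
        return wrapper
    return decorator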

Some files were not shown because too many files have changed in this diff.