Mirror of https://github.com/bytedance/deer-flow.git (synced 2026-04-28 12:48:40 +00:00)
feat(auth): authentication module with multi-tenant isolation (RFC-001)
Introduce an always-on auth layer with auto-created admin on first boot, multi-tenant isolation for threads/stores, and a full setup/login flow.

Backend
- JWT access tokens with `ver` field for stale-token rejection; bump on password/email change
- Password hashing, HttpOnly+Secure cookies (Secure derived from request scheme at runtime)
- CSRF middleware covering both REST and LangGraph routes
- IP-based login rate limiting (5 attempts / 5-min lockout) with bounded dict growth and X-Forwarded-For bypass fix
- Multi-worker-safe admin auto-creation (single DB write, WAL once)
- needs_setup + token_version on User model; SQLite schema migration
- Thread/store isolation by owner; orphan thread migration on first admin registration
- thread_id validated as UUID to prevent log injection
- CLI tool to reset admin password
- Decorator-based authz module extracted from auth core

Frontend
- Login and setup pages with SSR guard for needs_setup flow
- Account settings page (change password / email)
- AuthProvider + route guards; skips redirect when no users registered
- i18n (en-US / zh-CN) for auth surfaces
- Typed auth API client; parseAuthError unwraps FastAPI detail envelope

Infra & tooling
- Unified `serve.sh` with gateway mode + auto dep install
- Public PyPI uv.toml pin for CI compatibility
- Regenerated uv.lock with public index

Tests
- HTTP vs HTTPS cookie security tests
- Auth middleware, rate limiter, CSRF, setup flow coverage
This commit is contained in:
parent 636053fb6d
commit 27b66d6753
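The `ver` stale-token mechanism described in the message above can be sketched briefly. This is an illustration under assumptions, not code from this commit: it presumes PyJWT and a per-user `token_version` counter, as the commit message describes.

```python
import jwt  # PyJWT, assumed here for illustration

SECRET = "change-me"  # in DeerFlow this would come from AUTH_JWT_SECRET

def issue_token(user_id: str, token_version: int) -> str:
    # Embed the user's current token_version as the `ver` claim.
    return jwt.encode({"sub": user_id, "ver": token_version}, SECRET, algorithm="HS256")

def is_token_stale(token: str, current_version: int) -> bool:
    # After a password/email change, token_version is bumped in the DB,
    # so any token still carrying an older `ver` is rejected.
    payload = jwt.decode(token, SECRET, algorithms=["HS256"])
    return payload.get("ver") != current_version
```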
@@ -6,6 +6,11 @@ JINA_API_KEY=your-jina-api-key

 # InfoQuest API Key
 INFOQUEST_API_KEY=your-infoquest-api-key

+# Authentication — JWT secret for session signing
+# If not set, an ephemeral secret is auto-generated (sessions lost on restart)
+# Generate with: python -c "import secrets; print(secrets.token_urlsafe(32))"
+# AUTH_JWT_SECRET=your-secure-jwt-secret-here
+
 # CORS Origins (comma-separated) - e.g., http://localhost:3000,http://localhost:3001
 # CORS_ORIGINS=http://localhost:3000

@@ -32,3 +37,5 @@ INFOQUEST_API_KEY=your-infoquest-api-key

 # GitHub API Token
 # GITHUB_TOKEN=your-github-token
+# WECOM_BOT_ID=your-wecom-bot-id
+# WECOM_BOT_SECRET=your-wecom-bot-secret
.gitignore (vendored, 1 change)

@@ -54,3 +54,4 @@ web/

 # Deployment artifacts
 backend/Dockerfile.langgraph
 config.yaml.bak
+.gstack/
@@ -310,7 +310,7 @@ Every pull request runs the backend regression workflow at [.github/workflows/ba

 - [Configuration Guide](backend/docs/CONFIGURATION.md) - Setup and configuration
 - [Architecture Overview](backend/CLAUDE.md) - Technical architecture
-- [MCP Setup Guide](MCP_SETUP.md) - Model Context Protocol configuration
+- [MCP Setup Guide](backend/docs/MCP_SERVER.md) - Model Context Protocol configuration

 ## Need Help?
Makefile (82 changes)

@@ -1,6 +1,6 @@

 # DeerFlow - Unified Development Environment

-.PHONY: help config config-upgrade check install dev dev-daemon start stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway
+.PHONY: help config config-upgrade check install dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway

 BASH ?= bash

@@ -20,18 +20,25 @@ help:
 	@echo " make install - Install all dependencies (frontend + backend)"
 	@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
 	@echo " make dev - Start all services in development mode (with hot-reloading)"
-	@echo " make dev-daemon - Start all services in background (daemon mode)"
+	@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
+	@echo " make dev-daemon - Start dev services in background (daemon mode)"
+	@echo " make dev-daemon-pro - Start dev daemon + Gateway mode (experimental)"
 	@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
+	@echo " make start-pro - Start in prod + Gateway mode (experimental)"
+	@echo " make start-daemon - Start prod services in background (daemon mode)"
+	@echo " make start-daemon-pro - Start prod daemon + Gateway mode (experimental)"
 	@echo " make stop - Stop all running services"
 	@echo " make clean - Clean up processes and temporary files"
 	@echo ""
 	@echo "Docker Production Commands:"
 	@echo " make up - Build and start production Docker services (localhost:2026)"
+	@echo " make up-pro - Build and start production Docker in Gateway mode (experimental)"
 	@echo " make down - Stop and remove production Docker containers"
 	@echo ""
 	@echo "Docker Development Commands:"
 	@echo " make docker-init - Pull the sandbox image"
 	@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
+	@echo " make docker-start-pro - Start Docker in Gateway mode (experimental, no LangGraph container)"
 	@echo " make docker-stop - Stop Docker development services"
 	@echo " make docker-logs - View Docker development logs"
 	@echo " make docker-logs-frontend - View Docker frontend logs"

@@ -105,6 +112,15 @@ else
 	@./scripts/serve.sh --dev
 endif

+# Start all services in dev + Gateway mode (experimental: agent runtime embedded in Gateway)
+dev-pro:
+	@$(PYTHON) ./scripts/check.py
+ifeq ($(OS),Windows_NT)
+	@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway
+else
+	@./scripts/serve.sh --dev --gateway
+endif
+
 # Start all services in production mode (with optimizations)
 start:
 	@$(PYTHON) ./scripts/check.py

@@ -114,30 +130,54 @@ else
 	@./scripts/serve.sh --prod
 endif

+# Start all services in prod + Gateway mode (experimental)
+start-pro:
+	@$(PYTHON) ./scripts/check.py
+ifeq ($(OS),Windows_NT)
+	@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway
+else
+	@./scripts/serve.sh --prod --gateway
+endif
+
 # Start all services in daemon mode (background)
 dev-daemon:
 	@$(PYTHON) ./scripts/check.py
 ifeq ($(OS),Windows_NT)
-	@call scripts\run-with-git-bash.cmd ./scripts/start-daemon.sh
+	@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --daemon
 else
-	@./scripts/start-daemon.sh
+	@./scripts/serve.sh --dev --daemon
 endif

+# Start daemon + Gateway mode (experimental)
+dev-daemon-pro:
+	@$(PYTHON) ./scripts/check.py
+ifeq ($(OS),Windows_NT)
+	@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --dev --gateway --daemon
+else
+	@./scripts/serve.sh --dev --gateway --daemon
+endif
+
+# Start prod services in daemon mode (background)
+start-daemon:
+	@$(PYTHON) ./scripts/check.py
+ifeq ($(OS),Windows_NT)
+	@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --daemon
+else
+	@./scripts/serve.sh --prod --daemon
+endif
+
+# Start prod daemon + Gateway mode (experimental)
+start-daemon-pro:
+	@$(PYTHON) ./scripts/check.py
+ifeq ($(OS),Windows_NT)
+	@call scripts\run-with-git-bash.cmd ./scripts/serve.sh --prod --gateway --daemon
+else
+	@./scripts/serve.sh --prod --gateway --daemon
+endif
+
 # Stop all services
 stop:
-	@echo "Stopping all services..."
-	@-pkill -f "langgraph dev" 2>/dev/null || true
-	@-pkill -f "uvicorn app.gateway.app:app" 2>/dev/null || true
-	@-pkill -f "next dev" 2>/dev/null || true
-	@-pkill -f "next start" 2>/dev/null || true
-	@-pkill -f "next-server" 2>/dev/null || true
-	@-nginx -c $(PWD)/docker/nginx/nginx.local.conf -p $(PWD) -s quit 2>/dev/null || true
-	@sleep 1
-	@-pkill -9 nginx 2>/dev/null || true
-	@echo "Cleaning up sandbox containers..."
-	@-./scripts/cleanup-containers.sh deer-flow-sandbox 2>/dev/null || true
-	@echo "✓ All services stopped"
+	@./scripts/serve.sh --stop

 # Clean up
 clean: stop

@@ -159,6 +199,10 @@ docker-init:
 docker-start:
 	@./scripts/docker.sh start

+# Start Docker in Gateway mode (experimental)
+docker-start-pro:
+	@./scripts/docker.sh start --gateway
+
 # Stop Docker development environment
 docker-stop:
 	@./scripts/docker.sh stop

@@ -181,6 +225,10 @@ docker-logs-gateway:
 up:
 	@./scripts/deploy.sh

+# Build and start production services in Gateway mode
+up-pro:
+	@./scripts/deploy.sh --gateway
+
 # Stop and remove production containers
 down:
 	@./scripts/deploy.sh down
README.md (77 changes)

@@ -46,6 +46,7 @@ DeerFlow has newly integrated the intelligent search and crawling toolset indepe

- [🦌 DeerFlow - 2.0](#-deerflow---20)
- [Official Website](#official-website)
- [Coding Plan from ByteDance Volcengine](#coding-plan-from-bytedance-volcengine)
- [InfoQuest](#infoquest)
- [Table of Contents](#table-of-contents)
- [One-Line Agent Setup](#one-line-agent-setup)

@@ -59,6 +60,8 @@ DeerFlow has newly integrated the intelligent search and crawling toolset indepe

- [MCP Server](#mcp-server)
- [IM Channels](#im-channels)
- [LangSmith Tracing](#langsmith-tracing)
- [Langfuse Tracing](#langfuse-tracing)
- [Using Both Providers](#using-both-providers)
- [From Deep Research to Super Agent Harness](#from-deep-research-to-super-agent-harness)
- [Core Features](#core-features)
- [Skills \& Tools](#skills--tools)

@@ -71,6 +74,8 @@ DeerFlow has newly integrated the intelligent search and crawling toolset indepe

- [Embedded Python Client](#embedded-python-client)
- [Documentation](#documentation)
- [⚠️ Security Notice](#️-security-notice)
- [Improper Deployment May Introduce Security Risks](#improper-deployment-may-introduce-security-risks)
- [Security Recommendations](#security-recommendations)
- [Contributing](#contributing)
- [License](#license)
- [Acknowledgments](#acknowledgments)

@@ -275,6 +280,60 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P

 6. **Access**: http://localhost:2026

+#### Startup Modes
+
+DeerFlow supports multiple startup modes across two dimensions:
+
+- **Dev / Prod** — dev enables hot-reload; prod uses pre-built frontend
+- **Standard / Gateway** — standard uses a separate LangGraph server (4 processes); Gateway mode (experimental) embeds the agent runtime in the Gateway API (3 processes)
+
+| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
+|---|---|---|---|---|
+| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
+| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
+| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
+| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
+
+| Action | Local | Docker Dev | Docker Prod |
+|---|---|---|---|
+| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
+| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
+
+> **Gateway mode** eliminates the LangGraph server process — the Gateway API handles agent execution directly via async tasks, managing its own concurrency.
+
+#### Why Gateway Mode?
+
+In standard mode, DeerFlow runs a dedicated [LangGraph Platform](https://langchain-ai.github.io/langgraph/) server alongside the Gateway API. This architecture works well but has trade-offs:
+
+| | Standard Mode | Gateway Mode |
+|---|---|---|
+| **Architecture** | Gateway (REST API) + LangGraph (agent runtime) | Gateway embeds agent runtime |
+| **Concurrency** | `--n-jobs-per-worker` per worker (requires license) | `--workers` × async tasks (no per-worker cap) |
+| **Containers / Processes** | 4 (frontend, gateway, langgraph, nginx) | 3 (frontend, gateway, nginx) |
+| **Resource usage** | Higher (two Python runtimes) | Lower (single Python runtime) |
+| **LangGraph Platform license** | Required for production images | Not required |
+| **Cold start** | Slower (two services to initialize) | Faster |
+
+Both modes are functionally equivalent — the same agents, tools, and skills work in either mode.
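As a rough illustration of the concurrency column above, here is a minimal sketch of running agent executions as plain asyncio tasks inside the API process. The class and method names are hypothetical, not DeerFlow's actual runtime:

```python
import asyncio


class EmbeddedRunManager:
    """Illustrative only: track agent runs as asyncio tasks in one process."""

    def __init__(self) -> None:
        self._runs: dict[str, asyncio.Task] = {}

    def start_run(self, run_id: str, agent_coro) -> None:
        # Each run becomes an event-loop task; concurrency is bounded only by
        # the event loop (and any semaphore the service chooses to add),
        # not by a per-worker job cap.
        self._runs[run_id] = asyncio.create_task(agent_coro)

    async def wait(self, run_id: str):
        # Await completion and forget the task.
        return await self._runs.pop(run_id)
```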
+
+#### Docker Production Deployment
+
+`deploy.sh` supports building and starting separately. Images are mode-agnostic — runtime mode is selected at start time:
+
+```bash
+# One-step (build + start)
+deploy.sh            # standard mode (default)
+deploy.sh --gateway  # gateway mode
+
+# Two-step (build once, start with any mode)
+deploy.sh build            # build all images
+deploy.sh start            # start in standard mode
+deploy.sh start --gateway  # start in gateway mode
+
+# Stop
+deploy.sh down
+```
+
 ### Advanced
 #### Sandbox Mode

@@ -302,6 +361,7 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when

 | Telegram | Bot API (long-polling) | Easy |
 | Slack | Socket Mode | Moderate |
 | Feishu / Lark | WebSocket | Moderate |
+| WeCom | WebSocket | Moderate |

 **Configuration in `config.yaml`:**

@@ -329,6 +389,11 @@ channels:
     # domain: https://open.feishu.cn      # China (default)
     # domain: https://open.larksuite.com  # International

+  wecom:
+    enabled: true
+    bot_id: $WECOM_BOT_ID
+    bot_secret: $WECOM_BOT_SECRET
+
   slack:
     enabled: true
     bot_token: $SLACK_BOT_TOKEN  # xoxb-...

@@ -372,6 +437,10 @@ SLACK_APP_TOKEN=xapp-...

 # Feishu / Lark
 FEISHU_APP_ID=cli_xxxx
 FEISHU_APP_SECRET=your_app_secret

+# WeCom
+WECOM_BOT_ID=your_bot_id
+WECOM_BOT_SECRET=your_bot_secret
 ```

 **Telegram Setup**

@@ -394,6 +463,14 @@ FEISHU_APP_SECRET=your_app_secret

 3. Under **Events**, subscribe to `im.message.receive_v1` and select **Long Connection** mode.
 4. Copy the App ID and App Secret. Set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env` and enable the channel in `config.yaml`.

+**WeCom Setup**
+
+1. Create a bot on the WeCom AI Bot platform and obtain the `bot_id` and `bot_secret`.
+2. Enable `channels.wecom` in `config.yaml` and fill in `bot_id` / `bot_secret`.
+3. Set `WECOM_BOT_ID` and `WECOM_BOT_SECRET` in `.env`.
+4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
+5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
+
 When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://langgraph:2024` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.

 **Commands**
README_zh.md (18 changes)

@@ -232,6 +232,7 @@ DeerFlow supports receiving tasks from messaging apps. Once configured, the corres

 | Telegram | Bot API (long-polling) | Easy |
 | Slack | Socket Mode | Moderate |
 | Feishu / Lark | WebSocket | Moderate |
+| WeCom AI Bot | WebSocket | Moderate |

 **Example configuration in `config.yaml`:**

@@ -259,6 +260,11 @@ channels:
     # domain: https://open.feishu.cn      # China (default)
     # domain: https://open.larksuite.com  # International

+  wecom:
+    enabled: true
+    bot_id: $WECOM_BOT_ID
+    bot_secret: $WECOM_BOT_SECRET
+
   slack:
     enabled: true
     bot_token: $SLACK_BOT_TOKEN  # xoxb-...

@@ -302,6 +308,10 @@ SLACK_APP_TOKEN=xapp-...

 # Feishu / Lark
 FEISHU_APP_ID=cli_xxxx
 FEISHU_APP_SECRET=your_app_secret

+# WeCom AI Bot
+WECOM_BOT_ID=your_bot_id
+WECOM_BOT_SECRET=your_bot_secret
 ```

 **Telegram Setup**

@@ -324,6 +334,14 @@ FEISHU_APP_SECRET=your_app_secret

 3. Under **Event Subscriptions**, subscribe to `im.message.receive_v1` and select **Long Connection** mode.
 4. Copy the App ID and App Secret, set `FEISHU_APP_ID` and `FEISHU_APP_SECRET` in `.env`, and enable the channel in `config.yaml`.

+**WeCom AI Bot Setup**
+
+1. Create a bot on the WeCom AI Bot platform and obtain the `bot_id` and `bot_secret`.
+2. Enable `channels.wecom` in `config.yaml` and fill in `bot_id` / `bot_secret`.
+3. Set `WECOM_BOT_ID` and `WECOM_BOT_SECRET` in `.env`.
+4. Make sure backend dependencies include `wecom-aibot-python-sdk`; the channel receives messages over a WebSocket long connection and needs no public callback URL.
+5. Inbound text, image, and file messages are currently supported; final images/files generated by the agent are also sent back to the WeCom conversation.
+
 **Commands**

 Once the channel is connected, you can interact with DeerFlow directly in the chat window:
@@ -13,6 +13,10 @@ DeerFlow is a LangGraph-based AI super agent system with a full-stack architectu

 - **Nginx** (port 2026): Unified reverse proxy entry point
 - **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode

+**Runtime Modes**:
+- **Standard mode** (`make dev`): LangGraph Server handles agent execution as a separate process. 4 processes total.
+- **Gateway mode** (`make dev-pro`, experimental): Agent runtime embedded in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Service manages its own concurrency via async tasks. 3 processes total, no LangGraph Server.
+
 **Project Structure**:
 ```
 deer-flow/

@@ -80,6 +84,8 @@ When making code changes, you MUST update the relevant documentation:
 make check      # Check system requirements
 make install    # Install all dependencies (frontend + backend)
 make dev        # Start all services (LangGraph + Gateway + Frontend + Nginx), with config.yaml preflight
+make dev-pro    # Gateway mode (experimental): skip LangGraph, agent runtime embedded in Gateway
+make start-pro  # Production + Gateway mode (experimental)
 make stop       # Stop all services
 ```

@@ -436,8 +442,25 @@ make dev

 This starts all services and makes the application available at `http://localhost:2026`.

+**All startup modes:**
+
+| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
+|---|---|---|---|---|
+| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
+| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
+| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
+| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
+
+| Action | Local | Docker Dev | Docker Prod |
+|---|---|---|---|
+| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
+| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
+
+Gateway mode embeds the agent runtime in the Gateway, so no LangGraph server runs.
+
 **Nginx routing**:
-- `/api/langgraph/*` → LangGraph Server (2024)
+- Standard mode: `/api/langgraph/*` → LangGraph Server (2024)
+- Gateway mode: `/api/langgraph/*` → Gateway embedded runtime (8001) (via envsubst)
 - `/api/*` (other) → Gateway API (8001)
 - `/` (non-API) → Frontend (3000)
@@ -1,34 +1,86 @@

-# Backend Development Dockerfile
+# Backend Dockerfile — multi-stage build
+# Stage 1 (builder): compiles native Python extensions with build-essential
+# Stage 2 (dev): retains toolchain for dev containers (uv sync at startup)
+# Stage 3 (runtime): clean image without compiler toolchain for production

 # UV source image (override for restricted networks that cannot reach ghcr.io)
 ARG UV_IMAGE=ghcr.io/astral-sh/uv:0.7.20
 FROM ${UV_IMAGE} AS uv-source

-FROM python:3.12-slim-bookworm
+# ── Stage 1: Builder ──────────────────────────────────────────────────────────
+FROM python:3.12-slim-bookworm AS builder

 ARG NODE_MAJOR=22
+ARG NODE_VERSION=22.16.0
 ARG APT_MIRROR
 ARG UV_INDEX_URL
+ARG NODE_DIST_URL

-# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
+# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.byted.org)
 RUN if [ -n "${APT_MIRROR}" ]; then \
     sed -i "s|deb.debian.org|${APT_MIRROR}|g" /etc/apt/sources.list.d/debian.sources 2>/dev/null || true; \
     sed -i "s|deb.debian.org|${APT_MIRROR}|g" /etc/apt/sources.list 2>/dev/null || true; \
     fi

-# Install system dependencies + Node.js (provides npx for MCP servers)
+# Install build tools + Node.js (build-essential needed for native Python extensions)
+# NODE_DIST_URL: base URL for Node.js binary tarballs in restricted networks.
+#   npmmirror: https://registry.npmmirror.com/-/binary/node
+#   official:  https://nodejs.org/dist (default, via nodesource apt)
 RUN apt-get update && apt-get install -y \
     curl \
     build-essential \
     gnupg \
     ca-certificates \
-    && mkdir -p /etc/apt/keyrings \
-    && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
-    && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" > /etc/apt/sources.list.d/nodesource.list \
-    && apt-get update \
-    && apt-get install -y nodejs \
+    xz-utils \
+    && if [ -n "${NODE_DIST_URL}" ]; then \
+        curl -fsSL "${NODE_DIST_URL}/v${NODE_VERSION}/node-v${NODE_VERSION}-linux-x64.tar.xz" \
+            | tar -xJ --strip-components=1 -C /usr/local \
+        && ln -sf /usr/local/bin/node /usr/bin/node \
+        && ln -sf /usr/local/lib/node_modules /usr/lib/node_modules; \
+    else \
+        mkdir -p /etc/apt/keyrings \
+        && curl -fsSL https://deb.nodesource.com/gpgkey/nodesource-repo.gpg.key | gpg --dearmor -o /etc/apt/keyrings/nodesource.gpg \
+        && echo "deb [signed-by=/etc/apt/keyrings/nodesource.gpg] https://deb.nodesource.com/node_${NODE_MAJOR}.x nodistro main" > /etc/apt/sources.list.d/nodesource.list \
+        && apt-get update \
+        && apt-get install -y nodejs; \
+    fi \
+    && rm -rf /var/lib/apt/lists/*

 # Install uv (source image overridable via UV_IMAGE build arg)
 COPY --from=uv-source /uv /uvx /usr/local/bin/

 # Set working directory
 WORKDIR /app

 # Copy backend source code
 COPY backend ./backend

 # Install dependencies with cache mount
 RUN --mount=type=cache,target=/root/.cache/uv \
     sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"

+# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
+# Retains compiler toolchain from builder so startup-time `uv sync` can build
+# source distributions in development containers.
+FROM builder AS dev
+
+# Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
+COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker
+
+EXPOSE 8001 2024
+
+CMD ["sh", "-c", "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001"]
+
+# ── Stage 3: Runtime ──────────────────────────────────────────────────────────
+# Clean image without build-essential — reduces size (~200 MB) and attack surface.
+FROM python:3.12-slim-bookworm
+
+# Copy Node.js runtime from builder (provides npx for MCP servers)
+COPY --from=builder /usr/bin/node /usr/bin/node
+COPY --from=builder /usr/lib/node_modules /usr/lib/node_modules
+RUN ln -s ../lib/node_modules/npm/bin/npm-cli.js /usr/bin/npm \
+    && ln -s ../lib/node_modules/npm/bin/npx-cli.js /usr/bin/npx

 # Install Docker CLI (for DooD: allows starting sandbox containers via host Docker socket)
 COPY --from=docker:cli /usr/local/bin/docker /usr/local/bin/docker

@@ -38,12 +90,8 @@ COPY --from=uv-source /uv /uvx /usr/local/bin/

 # Set working directory
 WORKDIR /app

-# Copy frontend source code
-COPY backend ./backend
-
-# Install dependencies with cache mount
-RUN --mount=type=cache,target=/root/.cache/uv \
-    sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
+# Copy backend with pre-built virtualenv from builder
+COPY --from=builder /app/backend ./backend

 # Expose ports (gateway: 8001, langgraph: 2024)
 EXPOSE 8001 2024
@@ -2,7 +2,7 @@ install:
 	uv sync

 dev:
-	uv run langgraph dev --no-browser --allow-blocking --no-reload --n-jobs-per-worker 10
+	uv run langgraph dev --no-browser --no-reload --n-jobs-per-worker 10

 gateway:
 	PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc)
|
||||
raise last_exc # type: ignore[misc]
|
||||
if last_exc is None:
|
||||
raise RuntimeError("Feishu send failed without an exception from any attempt")
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
if not self._api_client:
|
||||
|
||||
@@ -7,9 +7,10 @@ import logging
 import mimetypes
 import re
 import time
-from collections.abc import Mapping
+from collections.abc import Awaitable, Callable, Mapping
 from typing import Any

 import httpx
 from langgraph_sdk.errors import ConflictError

 from app.channels.commands import KNOWN_CHANNEL_COMMANDS

@@ -36,8 +37,49 @@ CHANNEL_CAPABILITIES = {
     "feishu": {"supports_streaming": True},
     "slack": {"supports_streaming": False},
     "telegram": {"supports_streaming": False},
+    "wecom": {"supports_streaming": True},
 }

+InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
+
+
+INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
+
+
+def register_inbound_file_reader(channel_name: str, reader: InboundFileReader) -> None:
+    INBOUND_FILE_READERS[channel_name] = reader
+
+
+async def _read_http_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
+    url = file_info.get("url")
+    if not isinstance(url, str) or not url:
+        return None
+
+    resp = await client.get(url)
+    resp.raise_for_status()
+    return resp.content
+
+
+async def _read_wecom_inbound_file(file_info: dict[str, Any], client: httpx.AsyncClient) -> bytes | None:
+    data = await _read_http_inbound_file(file_info, client)
+    if data is None:
+        return None
+
+    aeskey = file_info.get("aeskey") if isinstance(file_info.get("aeskey"), str) else None
+    if not aeskey:
+        return data
+
+    try:
+        from aibot.crypto_utils import decrypt_file
+    except Exception:
+        logger.exception("[Manager] failed to import WeCom decrypt_file")
+        return None
+
+    return decrypt_file(data, aeskey)
+
+
+register_inbound_file_reader("wecom", _read_wecom_inbound_file)


 class InvalidChannelSessionConfigError(ValueError):
     """Raised when IM channel session overrides contain invalid agent config."""
@@ -342,6 +384,105 @@ def _prepare_artifact_delivery(
     return response_text, attachments


+async def _ingest_inbound_files(thread_id: str, msg: InboundMessage) -> list[dict[str, Any]]:
+    if not msg.files:
+        return []
+
+    from deerflow.uploads.manager import claim_unique_filename, ensure_uploads_dir, normalize_filename
+
+    uploads_dir = ensure_uploads_dir(thread_id)
+    seen_names = {entry.name for entry in uploads_dir.iterdir() if entry.is_file()}
+
+    created: list[dict[str, Any]] = []
+    file_reader = INBOUND_FILE_READERS.get(msg.channel_name, _read_http_inbound_file)
+    async with httpx.AsyncClient(timeout=httpx.Timeout(20.0)) as client:
+        for idx, f in enumerate(msg.files):
+            if not isinstance(f, dict):
+                continue
+
+            ftype = f.get("type") if isinstance(f.get("type"), str) else "file"
+            filename = f.get("filename") if isinstance(f.get("filename"), str) else ""
+
+            try:
+                data = await file_reader(f, client)
+            except Exception:
+                logger.exception(
+                    "[Manager] failed to read inbound file: channel=%s, file=%s",
+                    msg.channel_name,
+                    f.get("url") or filename or idx,
+                )
+                continue
+
+            if data is None:
+                logger.warning(
+                    "[Manager] inbound file reader returned no data: channel=%s, file=%s",
+                    msg.channel_name,
+                    f.get("url") or filename or idx,
+                )
+                continue
+
+            if not filename:
+                ext = ".bin"
+                if ftype == "image":
+                    ext = ".png"
+                filename = f"{msg.thread_ts or 'msg'}_{idx}{ext}"
+
+            try:
+                safe_name = claim_unique_filename(normalize_filename(filename), seen_names)
+            except ValueError:
+                logger.warning(
+                    "[Manager] skipping inbound file with unsafe filename: channel=%s, file=%r",
+                    msg.channel_name,
+                    filename,
+                )
+                continue
+
+            dest = uploads_dir / safe_name
+            try:
+                dest.write_bytes(data)
+            except Exception:
+                logger.exception("[Manager] failed to write inbound file: %s", dest)
+                continue
+
+            created.append(
+                {
+                    "filename": safe_name,
+                    "size": len(data),
+                    "path": f"/mnt/user-data/uploads/{safe_name}",
+                    "is_image": ftype == "image",
+                }
+            )
+
+    return created
+
+
+def _format_uploaded_files_block(files: list[dict[str, Any]]) -> str:
+    lines = [
+        "<uploaded_files>",
+        "The following files were uploaded in this message:",
+        "",
+    ]
+    if not files:
+        lines.append("(empty)")
+    else:
+        for f in files:
+            filename = f.get("filename", "")
+            size = int(f.get("size") or 0)
+            size_kb = size / 1024 if size else 0
+            size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
+            path = f.get("path", "")
+            is_image = bool(f.get("is_image"))
+            file_kind = "image" if is_image else "file"
+            lines.append(f"- {filename} ({size_str})")
+            lines.append(f"  Type: {file_kind}")
+            lines.append(f"  Path: {path}")
+        lines.append("")
+    lines.append("Use `read_file` for text-based files and documents.")
+    lines.append("Use `view_image` for image files (jpg, jpeg, png, webp) so the model can inspect the image content.")
+    lines.append("</uploaded_files>")
+    return "\n".join(lines)


 class ChannelManager:
     """Core dispatcher that bridges IM channels to the DeerFlow agent.
@@ -536,6 +677,11 @@ class ChannelManager:
         assistant_id, run_config, run_context = self._resolve_run_params(msg, thread_id)
         if extra_context:
             run_context.update(extra_context)

+        uploaded = await _ingest_inbound_files(thread_id, msg)
+        if uploaded:
+            msg.text = f"{_format_uploaded_files_block(uploaded)}\n\n{msg.text}".strip()
+
         if self._channel_supports_streaming(msg.channel_name):
             await self._handle_streaming_chat(
                 client,
@@ -17,6 +17,7 @@ _CHANNEL_REGISTRY: dict[str, str] = {
     "feishu": "app.channels.feishu:FeishuChannel",
     "slack": "app.channels.slack:SlackChannel",
     "telegram": "app.channels.telegram:TelegramChannel",
+    "wecom": "app.channels.wecom:WeComChannel",
 }

 _CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
@@ -30,7 +30,7 @@ class SlackChannel(Channel):
         self._socket_client = None
         self._web_client = None
         self._loop: asyncio.AbstractEventLoop | None = None
-        self._allowed_users: set[str] = set(config.get("allowed_users", []))
+        self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}

     async def start(self) -> None:
         if self._running:

@@ -126,7 +126,9 @@ class SlackChannel(Channel):
             )
         except Exception:
             pass
-        raise last_exc  # type: ignore[misc]
+        if last_exc is None:
+            raise RuntimeError("Slack send failed without an exception from any attempt")
+        raise last_exc

     async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
         if not self._web_client:
@@ -125,7 +125,9 @@ class TelegramChannel(Channel):
             await asyncio.sleep(delay)

         logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc)
-        raise last_exc  # type: ignore[misc]
+        if last_exc is None:
+            raise RuntimeError("Telegram send failed without an exception from any attempt")
+        raise last_exc

     async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
         if not self._application:
backend/app/channels/wecom.py (new file, 394 lines)

@@ -0,0 +1,394 @@
from __future__ import annotations

import asyncio
import base64
import hashlib
import logging
from collections.abc import Awaitable, Callable
from typing import Any, cast

from app.channels.base import Channel
from app.channels.message_bus import (
    InboundMessageType,
    MessageBus,
    OutboundMessage,
    ResolvedAttachment,
)

logger = logging.getLogger(__name__)


class WeComChannel(Channel):
    def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
        super().__init__(name="wecom", bus=bus, config=config)
        self._bot_id: str | None = None
        self._bot_secret: str | None = None
        self._ws_client = None
        self._ws_task: asyncio.Task | None = None
        self._ws_frames: dict[str, dict[str, Any]] = {}
        self._ws_stream_ids: dict[str, str] = {}
        self._working_message = "Working on it..."

    def _clear_ws_context(self, thread_ts: str | None) -> None:
        if not thread_ts:
            return
        self._ws_frames.pop(thread_ts, None)
        self._ws_stream_ids.pop(thread_ts, None)

    async def _send_ws_upload_command(self, req_id: str, body: dict[str, Any], cmd: str) -> dict[str, Any]:
        if not self._ws_client:
            raise RuntimeError("WeCom WebSocket client is not available")

        ws_manager = getattr(self._ws_client, "_ws_manager", None)
        send_reply = getattr(ws_manager, "send_reply", None)
        if not callable(send_reply):
            raise RuntimeError("Installed wecom-aibot-python-sdk does not expose the WebSocket media upload API expected by DeerFlow. Use wecom-aibot-python-sdk==0.1.6 or update the adapter.")

        send_reply_async = cast(Callable[[str, dict[str, Any], str], Awaitable[dict[str, Any]]], send_reply)
        return await send_reply_async(req_id, body, cmd)

    async def start(self) -> None:
        if self._running:
            return

        bot_id = self.config.get("bot_id")
        bot_secret = self.config.get("bot_secret")
        working_message = self.config.get("working_message")

        self._bot_id = bot_id if isinstance(bot_id, str) and bot_id else None
        self._bot_secret = bot_secret if isinstance(bot_secret, str) and bot_secret else None
        self._working_message = working_message if isinstance(working_message, str) and working_message else "Working on it..."

        if not self._bot_id or not self._bot_secret:
            logger.error("WeCom channel requires bot_id and bot_secret")
            return

        try:
            from aibot import WSClient, WSClientOptions
        except ImportError:
            logger.error("wecom-aibot-python-sdk is not installed. Install it with: uv add wecom-aibot-python-sdk")
            return
        else:
            self._ws_client = WSClient(WSClientOptions(bot_id=self._bot_id, secret=self._bot_secret, logger=logger))
            self._ws_client.on("message.text", self._on_ws_text)
            self._ws_client.on("message.mixed", self._on_ws_mixed)
            self._ws_client.on("message.image", self._on_ws_image)
            self._ws_client.on("message.file", self._on_ws_file)
            self._ws_task = asyncio.create_task(self._ws_client.connect())

        self._running = True
        self.bus.subscribe_outbound(self._on_outbound)
        logger.info("WeCom channel started")

    async def stop(self) -> None:
        self._running = False
        self.bus.unsubscribe_outbound(self._on_outbound)
        if self._ws_task:
            try:
                self._ws_task.cancel()
            except Exception:
                pass
            self._ws_task = None
        if self._ws_client:
            try:
                self._ws_client.disconnect()
            except Exception:
                pass
            self._ws_client = None
        self._ws_frames.clear()
        self._ws_stream_ids.clear()
        logger.info("WeCom channel stopped")

    async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
        if self._ws_client:
            await self._send_ws(msg, _max_retries=_max_retries)
            return
        logger.warning("[WeCom] send called but WebSocket client is not available")

    async def _on_outbound(self, msg: OutboundMessage) -> None:
        if msg.channel_name != self.name:
            return

        try:
            await self.send(msg)
        except Exception:
            logger.exception("Failed to send outbound message on channel %s", self.name)
            if msg.is_final:
                self._clear_ws_context(msg.thread_ts)
            return

        for attachment in msg.attachments:
            try:
                success = await self.send_file(msg, attachment)
                if not success:
                    logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename)
            except Exception:
                logger.exception("[%s] failed to upload file %s", self.name, attachment.filename)

        if msg.is_final:
            self._clear_ws_context(msg.thread_ts)

    async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
        if not msg.is_final:
            return True
        if not self._ws_client:
            return False
        if not msg.thread_ts:
            return False
        frame = self._ws_frames.get(msg.thread_ts)
        if not frame:
            return False

        media_type = "image" if attachment.is_image else "file"
        size_limit = 2 * 1024 * 1024 if attachment.is_image else 20 * 1024 * 1024
        if attachment.size > size_limit:
            logger.warning(
                "[WeCom] %s too large (%d bytes), skipping: %s",
                media_type,
                attachment.size,
                attachment.filename,
            )
            return False

        try:
            media_id = await self._upload_media_ws(
                media_type=media_type,
                filename=attachment.filename,
                path=str(attachment.actual_path),
                size=attachment.size,
            )
            if not media_id:
                return False

            body = {media_type: {"media_id": media_id}, "msgtype": media_type}
            await self._ws_client.reply(frame, body)
            logger.debug("[WeCom] %s sent via ws: %s", media_type, attachment.filename)
            return True
        except Exception:
            logger.exception("[WeCom] failed to upload/send file via ws: %s", attachment.filename)
            return False

    async def _on_ws_text(self, frame: dict[str, Any]) -> None:
        body = frame.get("body", {}) or {}
        text = ((body.get("text") or {}).get("content") or "").strip()
        quote = body.get("quote", {}).get("text", {}).get("content", "").strip()
        if not text and not quote:
            return
        await self._publish_ws_inbound(frame, text + (f"\nQuote message: {quote}" if quote else ""))

    async def _on_ws_mixed(self, frame: dict[str, Any]) -> None:
        body = frame.get("body", {}) or {}
        mixed = body.get("mixed") or {}
        items = mixed.get("msg_item") or []
        parts: list[str] = []
        files: list[dict[str, Any]] = []
        for item in items:
            item_type = (item or {}).get("msgtype")
            if item_type == "text":
                content = (((item or {}).get("text") or {}).get("content") or "").strip()
                if content:
                    parts.append(content)
            elif item_type in ("image", "file"):
                payload = (item or {}).get(item_type) or {}
                url = payload.get("url")
                aeskey = payload.get("aeskey")
                if isinstance(url, str) and url:
                    files.append(
                        {
                            "type": item_type,
                            "url": url,
                            "aeskey": (aeskey if isinstance(aeskey, str) and aeskey else None),
                        }
                    )
        text = "\n\n".join(parts).strip()
        if not text and not files:
            return
        if not text:
            text = "(received image/file)"
        await self._publish_ws_inbound(frame, text, files=files)

    async def _on_ws_image(self, frame: dict[str, Any]) -> None:
        body = frame.get("body", {}) or {}
        image = body.get("image") or {}
        url = image.get("url")
        aeskey = image.get("aeskey")
        if not isinstance(url, str) or not url:
            return
        await self._publish_ws_inbound(
            frame,
            "(received image)",
            files=[
                {
                    "type": "image",
                    "url": url,
                    "aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
                }
            ],
        )

    async def _on_ws_file(self, frame: dict[str, Any]) -> None:
        body = frame.get("body", {}) or {}
        file_obj = body.get("file") or {}
        url = file_obj.get("url")
        aeskey = file_obj.get("aeskey")
        if not isinstance(url, str) or not url:
            return
        await self._publish_ws_inbound(
            frame,
            "(received file)",
            files=[
                {
                    "type": "file",
                    "url": url,
                    "aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
                }
            ],
        )

    async def _publish_ws_inbound(
        self,
        frame: dict[str, Any],
        text: str,
        *,
        files: list[dict[str, Any]] | None = None,
    ) -> None:
        if not self._ws_client:
            return
        try:
            from aibot import generate_req_id
        except Exception:
            return

        body = frame.get("body", {}) or {}
        msg_id = body.get("msgid")
        if not msg_id:
            return

        user_id = (body.get("from") or {}).get("userid")

        inbound_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
        inbound = self._make_inbound(
            chat_id=user_id,  # keep user's conversation in memory
            user_id=user_id,
            text=text,
            msg_type=inbound_type,
            thread_ts=msg_id,
            files=files or [],
            metadata={"aibotid": body.get("aibotid"), "chattype": body.get("chattype")},
        )
        inbound.topic_id = user_id  # keep the same thread

        stream_id = generate_req_id("stream")
        self._ws_frames[msg_id] = frame
        self._ws_stream_ids[msg_id] = stream_id

        try:
            await self._ws_client.reply_stream(frame, stream_id, self._working_message, False)
        except Exception:
            pass

        await self.bus.publish_inbound(inbound)

    async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
        if not self._ws_client:
            return
        try:
            from aibot import generate_req_id
        except Exception:
            generate_req_id = None

        if msg.thread_ts and msg.thread_ts in self._ws_frames:
            frame = self._ws_frames[msg.thread_ts]
            stream_id = self._ws_stream_ids.get(msg.thread_ts)
            if not stream_id and generate_req_id:
                stream_id = generate_req_id("stream")
                self._ws_stream_ids[msg.thread_ts] = stream_id
            if not stream_id:
                return

            last_exc: Exception | None = None
            for attempt in range(_max_retries):
                try:
                    await self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final))
                    return
                except Exception as exc:
                    last_exc = exc
                    if attempt < _max_retries - 1:
                        await asyncio.sleep(2**attempt)
            if last_exc:
                raise last_exc

        body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
        last_exc = None
        for attempt in range(_max_retries):
            try:
                await self._ws_client.send_message(msg.chat_id, body)
                return
            except Exception as exc:
                last_exc = exc
                if attempt < _max_retries - 1:
                    await asyncio.sleep(2**attempt)
        if last_exc:
            raise last_exc

    async def _upload_media_ws(
        self,
        *,
        media_type: str,
        filename: str,
        path: str,
        size: int,
    ) -> str | None:
        if not self._ws_client:
            return None
        try:
            from aibot import generate_req_id
        except Exception:
            return None

        chunk_size = 512 * 1024
        total_chunks = (size + chunk_size - 1) // chunk_size
        if total_chunks < 1 or total_chunks > 100:
            logger.warning("[WeCom] invalid total_chunks=%d for %s", total_chunks, filename)
            return None

        md5_hasher = hashlib.md5()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                md5_hasher.update(chunk)
        md5 = md5_hasher.hexdigest()

        init_req_id = generate_req_id("aibot_upload_media_init")
        init_body = {
            "type": media_type,
            "filename": filename,
            "total_size": int(size),
            "total_chunks": int(total_chunks),
            "md5": md5,
        }
        init_ack = await self._send_ws_upload_command(init_req_id, init_body, "aibot_upload_media_init")
        upload_id = (init_ack.get("body") or {}).get("upload_id")
        if not upload_id:
            logger.warning("[WeCom] upload init returned no upload_id: %s", init_ack)
            return None

        with open(path, "rb") as f:
            for idx in range(total_chunks):
                data = f.read(chunk_size)
                if not data:
                    break
                chunk_req_id = generate_req_id("aibot_upload_media_chunk")
                chunk_body = {
                    "upload_id": upload_id,
                    "chunk_index": int(idx),
                    "base64_data": base64.b64encode(data).decode("utf-8"),
                }
                await self._send_ws_upload_command(chunk_req_id, chunk_body, "aibot_upload_media_chunk")

        finish_req_id = generate_req_id("aibot_upload_media_finish")
        finish_ack = await self._send_ws_upload_command(finish_req_id, {"upload_id": upload_id}, "aibot_upload_media_finish")
        media_id = (finish_ack.get("body") or {}).get("media_id")
        if not media_id:
            logger.warning("[WeCom] upload finish returned no media_id: %s", finish_ack)
            return None
        return media_id
@@ -1,15 +1,21 @@
 import logging
+import os
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
+from datetime import UTC

 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware

+from app.gateway.auth_middleware import AuthMiddleware
 from app.gateway.config import get_gateway_config
+from app.gateway.csrf_middleware import CSRFMiddleware
 from app.gateway.deps import langgraph_runtime
 from app.gateway.routers import (
     agents,
     artifacts,
     assistants_compat,
+    auth,
     channels,
     mcp,
     memory,
@@ -33,6 +39,88 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)


+async def _ensure_admin_user(app: FastAPI) -> None:
+    """Auto-create the admin user on first boot if no users exist.
+
+    Prints the generated password to stdout so the operator can log in.
+    On subsequent boots, warns if any user still needs setup.
+
+    Multi-worker safe: relies on SQLite UNIQUE constraint to resolve races.
+    Only the worker that successfully creates/updates the admin prints the
+    password; losers silently skip.
+    """
+    import secrets
+
+    from app.gateway.deps import get_local_provider
+
+    provider = get_local_provider()
+    user_count = await provider.count_users()
+
+    if user_count == 0:
+        password = secrets.token_urlsafe(16)
+        try:
+            admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
+        except ValueError:
+            return  # Another worker already created the admin.
+
+        # Migrate orphaned threads (no user_id) to this admin
+        store = getattr(app.state, "store", None)
+        if store is not None:
+            await _migrate_orphaned_threads(store, str(admin.id))
+
+        logger.info("=" * 60)
+        logger.info(" Admin account created on first boot")
+        logger.info(" Email: %s", admin.email)
+        logger.info(" Password: %s", password)
+        logger.info(" Change it after login: Settings -> Account")
+        logger.info("=" * 60)
+        return
+
+    # Admin exists but setup never completed — reset password so operator
+    # can always find it in the console without needing the CLI.
+    # Multi-worker guard: if admin was created less than 30s ago, another
+    # worker just created it and will print the password — skip reset.
+    admin = await provider.get_user_by_email("admin@deerflow.dev")
+    if admin and admin.needs_setup:
+        import time
+
+        age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
+        if age < 30:
+            return  # Just created by another worker in this startup; its password is still valid.
+
+        from app.gateway.auth.password import hash_password_async
+
+        password = secrets.token_urlsafe(16)
+        admin.password_hash = await hash_password_async(password)
+        admin.token_version += 1
+        await provider.update_user(admin)
+
+        logger.info("=" * 60)
+        logger.info(" Admin account setup incomplete — password reset")
+        logger.info(" Email: %s", admin.email)
+        logger.info(" Password: %s", password)
+        logger.info(" Change it after login: Settings -> Account")
+        logger.info("=" * 60)
+
+
+async def _migrate_orphaned_threads(store, admin_user_id: str) -> None:
+    """Migrate threads with no user_id to the given admin."""
+    try:
+        migrated = 0
+        results = await store.asearch(("threads",), limit=1000)
+        for item in results:
+            metadata = item.value.get("metadata", {})
+            if not metadata.get("user_id"):
+                metadata["user_id"] = admin_user_id
+                item.value["metadata"] = metadata
+                await store.aput(("threads",), item.key, item.value)
+                migrated += 1
+        if migrated:
+            logger.info("Migrated %d orphaned thread(s) to admin", migrated)
+    except Exception:
+        logger.exception("Thread migration failed (non-fatal)")
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     """Application lifespan handler."""
@@ -52,6 +140,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     async with langgraph_runtime(app):
         logger.info("LangGraph runtime initialised")

+        # Ensure admin user exists (auto-create on first boot)
+        # Must run AFTER langgraph_runtime so app.state.store is available for thread migration
+        await _ensure_admin_user(app)
+
         # Start IM channel service if any channels are configured
         try:
             from app.channels.service import start_channel_service
@@ -163,7 +255,30 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
     ],
 )

-# CORS is handled by nginx - no need for FastAPI middleware
+# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
+app.add_middleware(AuthMiddleware)
+
+# CSRF: Double Submit Cookie pattern for state-changing requests
+app.add_middleware(CSRFMiddleware)
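For readers unfamiliar with the pattern named in the comment, here is a minimal, illustrative sketch of a Double Submit Cookie check; it is not the `CSRFMiddleware` implementation from this commit:

```python
import secrets

from fastapi import HTTPException, Request


def issue_csrf_cookie(response) -> str:
    # The token is readable by the frontend (not HttpOnly) so JavaScript can
    # echo it back in a header on state-changing requests.
    token = secrets.token_urlsafe(32)
    response.set_cookie("csrf_token", token, secure=True, samesite="lax")
    return token


def check_double_submit(request: Request) -> None:
    # Reject when the header copy does not match the cookie copy.
    cookie = request.cookies.get("csrf_token")
    header = request.headers.get("X-CSRF-Token")
    if not cookie or not secrets.compare_digest(cookie, header or ""):
        raise HTTPException(status_code=403, detail="CSRF token mismatch")
```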
+
+# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware
+cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
+if cors_origins_env:
+    cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
+    # Validate: wildcard origin with credentials is a security misconfiguration
+    for origin in cors_origins:
+        if origin == "*":
+            logger.error(
+                "GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. "
+                "This is a security misconfiguration — browsers will reject the response. "
+                "Use explicit scheme://host:port origins instead."
+            )
+            cors_origins = [o for o in cors_origins if o != "*"]
+            break
+    if cors_origins:
+        app.add_middleware(
+            CORSMiddleware,
+            allow_origins=cors_origins,
+            allow_credentials=True,
+            allow_methods=["*"],
+            allow_headers=["*"],
+        )

 # Include routers
 # Models API is mounted at /api/models

@@ -199,6 +314,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
 # Assistants compatibility API (LangGraph Platform stub)
 app.include_router(assistants_compat.router)

+# Auth API is mounted at /api/v1/auth
+app.include_router(auth.router)
+
 # Thread Runs API (LangGraph Platform-compatible runs lifecycle)
 app.include_router(thread_runs.router)
backend/app/gateway/auth/__init__.py (new file, 42 lines)

@@ -0,0 +1,42 @@
"""Authentication module for DeerFlow.

This module provides:
- JWT-based authentication
- Provider Factory pattern for extensible auth methods
- UserRepository interface for storage backends (SQLite)
"""

from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.models import User, UserResponse
from app.gateway.auth.password import hash_password, verify_password
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository

__all__ = [
    # Config
    "AuthConfig",
    "get_auth_config",
    "set_auth_config",
    # Errors
    "AuthErrorCode",
    "AuthErrorResponse",
    "TokenError",
    # JWT
    "TokenPayload",
    "create_access_token",
    "decode_token",
    # Password
    "hash_password",
    "verify_password",
    # Models
    "User",
    "UserResponse",
    # Providers
    "AuthProvider",
    "LocalAuthProvider",
    # Repository
    "UserRepository",
]
55 backend/app/gateway/auth/config.py Normal file
@@ -0,0 +1,55 @@
"""Authentication configuration for DeerFlow."""

import logging
import os
import secrets

from dotenv import load_dotenv
from pydantic import BaseModel, Field

load_dotenv()

logger = logging.getLogger(__name__)


class AuthConfig(BaseModel):
    """JWT and auth-related configuration. Parsed once at startup."""

    jwt_secret: str = Field(
        ...,
        description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
    )
    token_expiry_days: int = Field(default=7, ge=1, le=30)
    users_db_path: str | None = Field(
        default=None,
        description="Path to users SQLite DB. Defaults to .deer-flow/users.db",
    )
    oauth_github_client_id: str | None = Field(default=None)
    oauth_github_client_secret: str | None = Field(default=None)


_auth_config: AuthConfig | None = None


def get_auth_config() -> AuthConfig:
    """Get the global AuthConfig instance. Parses from env on first call."""
    global _auth_config
    if _auth_config is None:
        jwt_secret = os.environ.get("AUTH_JWT_SECRET")
        if not jwt_secret:
            jwt_secret = secrets.token_urlsafe(32)
            os.environ["AUTH_JWT_SECRET"] = jwt_secret
            logger.warning(
                "⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
                "Sessions will be invalidated on restart. "
                "For production, add AUTH_JWT_SECRET to your .env file. Generate one with: "
                'python -c "import secrets; print(secrets.token_urlsafe(32))"'
            )
        _auth_config = AuthConfig(jwt_secret=jwt_secret)
    return _auth_config


def set_auth_config(config: AuthConfig) -> None:
    """Set the global AuthConfig instance (for testing)."""
    global _auth_config
    _auth_config = config
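Aside: a minimal usage sketch, not part of this commit, showing how tests can pin a deterministic secret through set_auth_config instead of relying on the env-derived singleton.

    # Illustrative only; not in the diff above.
    from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config

    set_auth_config(AuthConfig(jwt_secret="test-only-secret", token_expiry_days=1))
    assert get_auth_config().token_expiry_days == 1  # the pinned instance wins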
44 backend/app/gateway/auth/errors.py Normal file
@@ -0,0 +1,44 @@
"""Typed error definitions for auth module.

AuthErrorCode: exhaustive enum of all auth failure conditions.
TokenError: exhaustive enum of JWT decode failures.
AuthErrorResponse: structured error payload for HTTP responses.
"""

from enum import StrEnum

from pydantic import BaseModel


class AuthErrorCode(StrEnum):
    """Exhaustive list of auth error conditions."""

    INVALID_CREDENTIALS = "invalid_credentials"
    TOKEN_EXPIRED = "token_expired"
    TOKEN_INVALID = "token_invalid"
    USER_NOT_FOUND = "user_not_found"
    EMAIL_ALREADY_EXISTS = "email_already_exists"
    PROVIDER_NOT_FOUND = "provider_not_found"
    NOT_AUTHENTICATED = "not_authenticated"


class TokenError(StrEnum):
    """Exhaustive list of JWT decode failure reasons."""

    EXPIRED = "expired"
    INVALID_SIGNATURE = "invalid_signature"
    MALFORMED = "malformed"


class AuthErrorResponse(BaseModel):
    """Structured error response — replaces bare `detail` strings."""

    code: AuthErrorCode
    message: str


def token_error_to_code(err: TokenError) -> AuthErrorCode:
    """Map TokenError to AuthErrorCode — single source of truth."""
    if err == TokenError.EXPIRED:
        return AuthErrorCode.TOKEN_EXPIRED
    return AuthErrorCode.TOKEN_INVALID
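Aside: a quick sketch of the mapping above (illustrative, not part of the commit). Every decode failure collapses to one of two HTTP-facing codes.

    from app.gateway.auth.errors import AuthErrorCode, TokenError, token_error_to_code

    assert token_error_to_code(TokenError.EXPIRED) is AuthErrorCode.TOKEN_EXPIRED
    # Signature and structural failures both surface as the generic invalid-token code.
    assert token_error_to_code(TokenError.INVALID_SIGNATURE) is AuthErrorCode.TOKEN_INVALID
    assert token_error_to_code(TokenError.MALFORMED) is AuthErrorCode.TOKEN_INVALID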
55 backend/app/gateway/auth/jwt.py Normal file
@@ -0,0 +1,55 @@
"""JWT token creation and verification."""

from datetime import UTC, datetime, timedelta

import jwt
from pydantic import BaseModel

from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import TokenError


class TokenPayload(BaseModel):
    """JWT token payload."""

    sub: str  # user_id
    exp: datetime
    iat: datetime | None = None
    ver: int = 0  # token_version — must match User.token_version


def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
    """Create a JWT access token.

    Args:
        user_id: The user's UUID as string
        expires_delta: Optional custom expiry, defaults to 7 days
        token_version: User's current token_version for invalidation

    Returns:
        Encoded JWT string
    """
    config = get_auth_config()
    expiry = expires_delta or timedelta(days=config.token_expiry_days)

    now = datetime.now(UTC)
    payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
    return jwt.encode(payload, config.jwt_secret, algorithm="HS256")


def decode_token(token: str) -> TokenPayload | TokenError:
    """Decode and validate a JWT token.

    Returns:
        TokenPayload if valid, or a specific TokenError variant.
    """
    config = get_auth_config()
    try:
        payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
        return TokenPayload(**payload)
    except jwt.ExpiredSignatureError:
        return TokenError.EXPIRED
    except jwt.InvalidSignatureError:
        return TokenError.INVALID_SIGNATURE
    except jwt.PyJWTError:
        return TokenError.MALFORMED
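Aside: an illustrative round-trip, not part of the commit; the user id below is a made-up placeholder. Because decode_token returns a union instead of raising, callers branch on isinstance.

    from datetime import timedelta

    from app.gateway.auth.errors import TokenError
    from app.gateway.auth.jwt import create_access_token, decode_token

    user_id = "00000000-0000-0000-0000-000000000001"  # placeholder UUID, not a real user
    token = create_access_token(user_id, expires_delta=timedelta(minutes=5), token_version=1)

    payload = decode_token(token)
    if isinstance(payload, TokenError):
        print(f"rejected: {payload.value}")  # e.g. "expired" once the 5 minutes pass
    else:
        assert payload.sub == user_id and payload.ver == 1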
87 backend/app/gateway/auth/local_provider.py Normal file
@@ -0,0 +1,87 @@
"""Local email/password authentication provider."""

from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository


class LocalAuthProvider(AuthProvider):
    """Email/password authentication provider using local database."""

    def __init__(self, repository: UserRepository):
        """Initialize with a UserRepository.

        Args:
            repository: UserRepository implementation (SQLite)
        """
        self._repo = repository

    async def authenticate(self, credentials: dict) -> User | None:
        """Authenticate with email and password.

        Args:
            credentials: dict with 'email' and 'password' keys

        Returns:
            User if authentication succeeds, None otherwise
        """
        email = credentials.get("email")
        password = credentials.get("password")

        if not email or not password:
            return None

        user = await self._repo.get_user_by_email(email)
        if user is None:
            return None

        if user.password_hash is None:
            # OAuth user without local password
            return None

        if not await verify_password_async(password, user.password_hash):
            return None

        return user

    async def get_user(self, user_id: str) -> User | None:
        """Get user by ID."""
        return await self._repo.get_user_by_id(user_id)

    async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
        """Create a new local user.

        Args:
            email: User email address
            password: Plain text password (will be hashed)
            system_role: Role to assign ("admin" or "user")
            needs_setup: If True, user must complete setup on first login

        Returns:
            Created User instance
        """
        password_hash = await hash_password_async(password) if password else None
        user = User(
            email=email,
            password_hash=password_hash,
            system_role=system_role,
            needs_setup=needs_setup,
        )
        return await self._repo.create_user(user)

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Get user by OAuth provider and ID."""
        return await self._repo.get_user_by_oauth(provider, oauth_id)

    async def count_users(self) -> int:
        """Return total number of registered users."""
        return await self._repo.count_users()

    async def update_user(self, user: User) -> User:
        """Update an existing user."""
        return await self._repo.update_user(user)

    async def get_user_by_email(self, email: str) -> User | None:
        """Get user by email."""
        return await self._repo.get_user_by_email(email)
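Aside: a wiring sketch, not part of the commit, assuming a fresh users database (create_user raises ValueError on a duplicate email). It shows that bad credentials come back as None rather than as an exception.

    import asyncio

    from app.gateway.auth.local_provider import LocalAuthProvider
    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository


    async def demo() -> None:
        provider = LocalAuthProvider(repository=SQLiteUserRepository())
        await provider.create_user("admin@example.com", password="a-long-password", system_role="admin")
        # Wrong password and unknown email both return None, never raise.
        assert await provider.authenticate({"email": "admin@example.com", "password": "a-long-password"})
        assert await provider.authenticate({"email": "admin@example.com", "password": "wrong"}) is None


    asyncio.run(demo())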
41 backend/app/gateway/auth/models.py Normal file
@@ -0,0 +1,41 @@
"""User Pydantic models for authentication."""

from datetime import UTC, datetime
from typing import Literal
from uuid import UUID, uuid4

from pydantic import BaseModel, ConfigDict, EmailStr, Field


def _utc_now() -> datetime:
    """Return current UTC time (timezone-aware)."""
    return datetime.now(UTC)


class User(BaseModel):
    """Internal user representation."""

    model_config = ConfigDict(from_attributes=True)

    id: UUID = Field(default_factory=uuid4, description="Primary key")
    email: EmailStr = Field(..., description="Unique email address")
    password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
    system_role: Literal["admin", "user"] = Field(default="user")
    created_at: datetime = Field(default_factory=_utc_now)

    # OAuth linkage (optional)
    oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
    oauth_id: str | None = Field(None, description="User ID from OAuth provider")

    # Auth lifecycle
    needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
    token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")


class UserResponse(BaseModel):
    """Response model for user info endpoint."""

    id: str
    email: str
    system_role: Literal["admin", "user"]
    needs_setup: bool = False
33 backend/app/gateway/auth/password.py Normal file
@@ -0,0 +1,33 @@
"""Password hashing utilities using bcrypt directly."""

import asyncio

import bcrypt


def hash_password(password: str) -> str:
    """Hash a password using bcrypt."""
    return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash."""
    return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))


async def hash_password_async(password: str) -> str:
    """Hash a password using bcrypt (non-blocking).

    Wraps the blocking bcrypt operation in a thread pool to avoid
    blocking the event loop during password hashing.
    """
    return await asyncio.to_thread(hash_password, password)


async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
    """Verify a password against its hash (non-blocking).

    Wraps the blocking bcrypt operation in a thread pool to avoid
    blocking the event loop during password verification.
    """
    return await asyncio.to_thread(verify_password, plain_password, hashed_password)
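Aside: a small round-trip sketch (illustrative, not part of the commit) of the sync and async variants.

    import asyncio

    from app.gateway.auth.password import hash_password, verify_password, verify_password_async

    digest = hash_password("correct horse battery staple")
    assert digest.startswith("$2")  # bcrypt output carries its salt in a $2x$ prefix
    assert verify_password("correct horse battery staple", digest)
    assert asyncio.run(verify_password_async("correct horse battery staple", digest))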
24 backend/app/gateway/auth/providers.py Normal file
@@ -0,0 +1,24 @@
"""Auth provider abstraction."""

from abc import ABC, abstractmethod


class AuthProvider(ABC):
    """Abstract base class for authentication providers."""

    @abstractmethod
    async def authenticate(self, credentials: dict) -> "User | None":
        """Authenticate user with given credentials.

        Returns User if authentication succeeds, None otherwise.
        """
        ...

    @abstractmethod
    async def get_user(self, user_id: str) -> "User | None":
        """Retrieve user by ID."""
        ...


# Import User at runtime to avoid circular imports
from app.gateway.auth.models import User  # noqa: E402
0 backend/app/gateway/auth/repositories/__init__.py Normal file
82 backend/app/gateway/auth/repositories/base.py Normal file
@@ -0,0 +1,82 @@
"""User repository interface for abstracting database operations."""

from abc import ABC, abstractmethod

from app.gateway.auth.models import User


class UserRepository(ABC):
    """Abstract interface for user data storage.

    Implement this interface to support different storage backends
    (currently SQLite).
    """

    @abstractmethod
    async def create_user(self, user: User) -> User:
        """Create a new user.

        Args:
            user: User object to create

        Returns:
            Created User with ID assigned

        Raises:
            ValueError: If email already exists
        """
        ...

    @abstractmethod
    async def get_user_by_id(self, user_id: str) -> User | None:
        """Get user by ID.

        Args:
            user_id: User UUID as string

        Returns:
            User if found, None otherwise
        """
        ...

    @abstractmethod
    async def get_user_by_email(self, email: str) -> User | None:
        """Get user by email.

        Args:
            email: User email address

        Returns:
            User if found, None otherwise
        """
        ...

    @abstractmethod
    async def update_user(self, user: User) -> User:
        """Update an existing user.

        Args:
            user: User object with updated fields

        Returns:
            Updated User
        """
        ...

    @abstractmethod
    async def count_users(self) -> int:
        """Return total number of registered users."""
        ...

    @abstractmethod
    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Get user by OAuth provider and ID.

        Args:
            provider: OAuth provider name (e.g. 'github', 'google')
            oauth_id: User ID from the OAuth provider

        Returns:
            User if found, None otherwise
        """
        ...
196 backend/app/gateway/auth/repositories/sqlite.py Normal file
@@ -0,0 +1,196 @@
"""SQLite implementation of UserRepository."""

import asyncio
import sqlite3
from contextlib import contextmanager
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from uuid import UUID

from app.gateway.auth.config import get_auth_config
from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserRepository

_resolved_db_path: Path | None = None
_table_initialized: bool = False


def _get_users_db_path() -> Path:
    """Get the users database path (resolved and cached once)."""
    global _resolved_db_path
    if _resolved_db_path is not None:
        return _resolved_db_path
    config = get_auth_config()
    if config.users_db_path:
        _resolved_db_path = Path(config.users_db_path)
    else:
        _resolved_db_path = Path(".deer-flow/users.db")
    _resolved_db_path.parent.mkdir(parents=True, exist_ok=True)
    return _resolved_db_path


def _get_connection() -> sqlite3.Connection:
    """Get a SQLite connection for the users database."""
    db_path = _get_users_db_path()
    conn = sqlite3.connect(str(db_path))
    conn.row_factory = sqlite3.Row
    return conn


def _init_users_table(conn: sqlite3.Connection) -> None:
    """Initialize the users table if it doesn't exist."""
    conn.execute("PRAGMA journal_mode=WAL")
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS users (
            id TEXT PRIMARY KEY,
            email TEXT UNIQUE NOT NULL,
            password_hash TEXT,
            system_role TEXT NOT NULL DEFAULT 'user',
            created_at REAL NOT NULL,
            oauth_provider TEXT,
            oauth_id TEXT,
            needs_setup INTEGER NOT NULL DEFAULT 0,
            token_version INTEGER NOT NULL DEFAULT 0
        )
        """
    )
    # Add unique constraint for OAuth identity to prevent duplicate social logins
    conn.execute(
        """
        CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
        ON users(oauth_provider, oauth_id)
        WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
        """
    )
    conn.commit()


@contextmanager
def _get_users_conn():
    """Context manager for users database connection."""
    global _table_initialized
    conn = _get_connection()
    try:
        if not _table_initialized:
            _init_users_table(conn)
            _table_initialized = True
        yield conn
    finally:
        conn.close()


class SQLiteUserRepository(UserRepository):
    """SQLite implementation of UserRepository."""

    async def create_user(self, user: User) -> User:
        """Create a new user in SQLite."""
        return await asyncio.to_thread(self._create_user_sync, user)

    def _create_user_sync(self, user: User) -> User:
        """Synchronous user creation (runs in thread pool)."""
        with _get_users_conn() as conn:
            try:
                conn.execute(
                    """
                    INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        str(user.id),
                        user.email,
                        user.password_hash,
                        user.system_role,
                        datetime.now(UTC).timestamp(),
                        user.oauth_provider,
                        user.oauth_id,
                        int(user.needs_setup),
                        user.token_version,
                    ),
                )
                conn.commit()
            except sqlite3.IntegrityError as e:
                if "UNIQUE constraint failed: users.email" in str(e):
                    raise ValueError(f"Email already registered: {user.email}") from e
                raise
        return user

    async def get_user_by_id(self, user_id: str) -> User | None:
        """Get user by ID from SQLite."""
        return await asyncio.to_thread(self._get_user_by_id_sync, user_id)

    def _get_user_by_id_sync(self, user_id: str) -> User | None:
        """Synchronous get by ID (runs in thread pool)."""
        with _get_users_conn() as conn:
            cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,))
            row = cursor.fetchone()
            if row is None:
                return None
            return self._row_to_user(dict(row))

    async def get_user_by_email(self, email: str) -> User | None:
        """Get user by email from SQLite."""
        return await asyncio.to_thread(self._get_user_by_email_sync, email)

    def _get_user_by_email_sync(self, email: str) -> User | None:
        """Synchronous get by email (runs in thread pool)."""
        with _get_users_conn() as conn:
            cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
            row = cursor.fetchone()
            if row is None:
                return None
            return self._row_to_user(dict(row))

    async def update_user(self, user: User) -> User:
        """Update an existing user in SQLite."""
        return await asyncio.to_thread(self._update_user_sync, user)

    def _update_user_sync(self, user: User) -> User:
        """Synchronous update (runs in thread pool)."""
        with _get_users_conn() as conn:
            conn.execute(
                "UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?",
                (user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)),
            )
            conn.commit()
        return user

    async def count_users(self) -> int:
        """Return total number of registered users."""
        return await asyncio.to_thread(self._count_users_sync)

    def _count_users_sync(self) -> int:
        """Synchronous count (runs in thread pool)."""
        with _get_users_conn() as conn:
            cursor = conn.execute("SELECT COUNT(*) FROM users")
            return cursor.fetchone()[0]

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Get user by OAuth provider and ID from SQLite."""
        return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id)

    def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None:
        """Synchronous get by OAuth (runs in thread pool)."""
        with _get_users_conn() as conn:
            cursor = conn.execute(
                "SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
                (provider, oauth_id),
            )
            row = cursor.fetchone()
            if row is None:
                return None
            return self._row_to_user(dict(row))

    @staticmethod
    def _row_to_user(row: dict[str, Any]) -> User:
        """Convert a database row to a User model."""
        return User(
            id=UUID(row["id"]),
            email=row["email"],
            password_hash=row["password_hash"],
            system_role=row["system_role"],
            created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
            oauth_provider=row.get("oauth_provider"),
            oauth_id=row.get("oauth_id"),
            needs_setup=bool(row["needs_setup"]),
            token_version=int(row["token_version"]),
        )
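Aside: an illustrative sketch, not part of the commit, run against a throwaway database path (/tmp/users-demo.db is an assumption). set_auth_config must run before the module resolves and caches its path.

    import asyncio

    from app.gateway.auth.config import AuthConfig, set_auth_config
    from app.gateway.auth.models import User
    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository

    # Point the module at a scratch database before the path is cached.
    set_auth_config(AuthConfig(jwt_secret="test-only-secret", users_db_path="/tmp/users-demo.db"))


    async def demo() -> None:
        repo = SQLiteUserRepository()
        created = await repo.create_user(User(email="first@example.com"))
        fetched = await repo.get_user_by_email("first@example.com")
        assert fetched is not None and fetched.id == created.id
        assert await repo.count_users() == 1


    asyncio.run(demo())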
66 backend/app/gateway/auth/reset_admin.py Normal file
@@ -0,0 +1,66 @@
"""CLI tool to reset admin password.

Usage:
    python -m app.gateway.auth.reset_admin
    python -m app.gateway.auth.reset_admin --email admin@example.com
"""

import argparse
import secrets
import sys

from app.gateway.auth.password import hash_password
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository


def main() -> None:
    parser = argparse.ArgumentParser(description="Reset admin password")
    parser.add_argument("--email", help="Admin email (default: first admin found)")
    args = parser.parse_args()

    repo = SQLiteUserRepository()

    # Find admin user synchronously (CLI context, no event loop)
    import asyncio

    user = asyncio.run(_find_admin(repo, args.email))
    if user is None:
        if args.email:
            print(f"Error: user '{args.email}' not found.", file=sys.stderr)
        else:
            print("Error: no admin user found.", file=sys.stderr)
        sys.exit(1)

    new_password = secrets.token_urlsafe(16)
    user.password_hash = hash_password(new_password)
    user.token_version += 1
    user.needs_setup = True
    asyncio.run(repo.update_user(user))

    print(f"Password reset for: {user.email}")
    print(f"New password: {new_password}")
    print("Next login will require setup (new email + password).")


async def _find_admin(repo: SQLiteUserRepository, email: str | None):
    if email:
        return await repo.get_user_by_email(email)
    # Find first admin
    import asyncio

    from app.gateway.auth.repositories.sqlite import _get_users_conn

    def _find_sync():
        with _get_users_conn() as conn:
            cursor = conn.execute("SELECT id FROM users WHERE system_role = 'admin' LIMIT 1")
            row = cursor.fetchone()
            return dict(row)["id"] if row else None

    admin_id = await asyncio.to_thread(_find_sync)
    if admin_id:
        return await repo.get_user_by_id(admin_id)
    return None


if __name__ == "__main__":
    main()
71 backend/app/gateway/auth_middleware.py Normal file
@@ -0,0 +1,71 @@
"""Global authentication middleware — fail-closed safety net.

Rejects unauthenticated requests to non-public paths with 401.
Fine-grained permission checks remain in authz.py decorators.
"""

from collections.abc import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp

from app.gateway.auth.errors import AuthErrorCode

# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
    "/health",
    "/docs",
    "/redoc",
    "/openapi.json",
)

# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
    {
        "/api/v1/auth/login/local",
        "/api/v1/auth/register",
        "/api/v1/auth/logout",
        "/api/v1/auth/setup-status",
    }
)


def _is_public(path: str) -> bool:
    stripped = path.rstrip("/")
    if stripped in _PUBLIC_EXACT_PATHS:
        return True
    return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)


class AuthMiddleware(BaseHTTPMiddleware):
    """Coarse-grained auth gate: reject requests without a valid session cookie.

    This does NOT verify JWT signature or user existence — that is the job of
    ``get_current_user_from_request`` in deps.py (called by ``@require_auth``).
    The middleware only checks *presence* of the cookie so that new endpoints
    that forget ``@require_auth`` are not completely exposed.
    """

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        if _is_public(request.url.path):
            return await call_next(request)

        # Non-public path: require session cookie
        if not request.cookies.get("access_token"):
            return JSONResponse(
                status_code=401,
                content={
                    "detail": {
                        "code": AuthErrorCode.NOT_AUTHENTICATED,
                        "message": "Authentication required",
                    }
                },
            )

        return await call_next(request)
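Aside: a doctest-style sketch (illustrative, not part of the commit) of which paths the coarse gate lets through, exercising the private _is_public helper above.

    from app.gateway.auth_middleware import _is_public

    assert _is_public("/health")
    assert _is_public("/api/v1/auth/login/local")
    assert _is_public("/api/v1/auth/setup-status/")  # trailing slash is stripped first
    assert not _is_public("/api/v1/auth/me")         # session-bound endpoints stay gated
    assert not _is_public("/api/v1/threads")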
261 backend/app/gateway/authz.py Normal file
@@ -0,0 +1,261 @@
"""Authorization decorators and context for DeerFlow.

Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py

**Usage:**

1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", owner_filter_key=...)`` for permission checks
3. The decorator chain processes from bottom to top

**Example:**

    @router.get("/{thread_id}")
    @require_auth
    @require_permission("threads", "read", owner_check=True)
    async def get_thread(thread_id: str, request: Request):
        # User is authenticated and has threads:read permission
        ...

**Permission Model:**

- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""

from __future__ import annotations

import functools
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar

from fastapi import HTTPException, Request

if TYPE_CHECKING:
    from app.gateway.auth.models import User

P = ParamSpec("P")
T = TypeVar("T")


# Permission constants
class Permissions:
    """Permission constants for resource:action format."""

    # Threads
    THREADS_READ = "threads:read"
    THREADS_WRITE = "threads:write"
    THREADS_DELETE = "threads:delete"

    # Runs
    RUNS_CREATE = "runs:create"
    RUNS_READ = "runs:read"
    RUNS_CANCEL = "runs:cancel"


class AuthContext:
    """Authentication context for the current request.

    Stored in request.state.auth after require_auth decoration.

    Attributes:
        user: The authenticated user, or None if anonymous
        permissions: List of permission strings (e.g., "threads:read")
    """

    __slots__ = ("user", "permissions")

    def __init__(self, user: User | None = None, permissions: list[str] | None = None):
        self.user = user
        self.permissions = permissions or []

    @property
    def is_authenticated(self) -> bool:
        """Check if user is authenticated."""
        return self.user is not None

    def has_permission(self, resource: str, action: str) -> bool:
        """Check if context has permission for resource:action.

        Args:
            resource: Resource name (e.g., "threads")
            action: Action name (e.g., "read")

        Returns:
            True if user has permission
        """
        permission = f"{resource}:{action}"
        return permission in self.permissions

    def require_user(self) -> User:
        """Get user or raise 401.

        Raises:
            HTTPException 401 if not authenticated
        """
        if not self.user:
            raise HTTPException(status_code=401, detail="Authentication required")
        return self.user


def get_auth_context(request: Request) -> AuthContext | None:
    """Get AuthContext from request state."""
    return getattr(request.state, "auth", None)


_ALL_PERMISSIONS: list[str] = [
    Permissions.THREADS_READ,
    Permissions.THREADS_WRITE,
    Permissions.THREADS_DELETE,
    Permissions.RUNS_CREATE,
    Permissions.RUNS_READ,
    Permissions.RUNS_CANCEL,
]


async def _authenticate(request: Request) -> AuthContext:
    """Authenticate request and return AuthContext.

    Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
    Returns AuthContext with user=None for anonymous requests.
    """
    from app.gateway.deps import get_optional_user_from_request

    user = await get_optional_user_from_request(request)
    if user is None:
        return AuthContext(user=None, permissions=[])

    # In future, permissions could be stored in user record
    return AuthContext(user=user, permissions=_ALL_PERMISSIONS)


def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
    """Decorator that authenticates the request and sets AuthContext.

    Must be placed ABOVE other decorators (executes after them).

    Usage:
        @router.get("/{thread_id}")
        @require_auth  # Bottom decorator (executes first after permission check)
        @require_permission("threads", "read")
        async def get_thread(thread_id: str, request: Request):
            auth: AuthContext = request.state.auth
            ...

    Raises:
        ValueError: If 'request' parameter is missing
    """

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> Any:
        request = kwargs.get("request")
        if request is None:
            raise ValueError("require_auth decorator requires 'request' parameter")

        # Authenticate and set context
        auth_context = await _authenticate(request)
        request.state.auth = auth_context

        return await func(*args, **kwargs)

    return wrapper


def require_permission(
    resource: str,
    action: str,
    owner_check: bool = False,
    owner_filter_key: str = "user_id",
    inject_record: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Decorator that checks permission for resource:action.

    Must be used AFTER @require_auth.

    Args:
        resource: Resource name (e.g., "threads", "runs")
        action: Action name (e.g., "read", "write", "delete")
        owner_check: If True, validates that the current user owns the resource.
            Requires 'thread_id' path parameter and performs ownership check.
        owner_filter_key: Field name for ownership filter (default: "user_id")
        inject_record: If True and owner_check is True, injects the thread record
            into kwargs['thread_record'] for use in the handler.

    Usage:
        # Simple permission check
        @require_permission("threads", "read")
        async def get_thread(thread_id: str, request: Request):
            ...

        # With ownership check (for /threads/{thread_id} endpoints)
        @require_permission("threads", "delete", owner_check=True)
        async def delete_thread(thread_id: str, request: Request):
            ...

        # With ownership check and record injection
        @require_permission("threads", "delete", owner_check=True, inject_record=True)
        async def delete_thread(thread_id: str, request: Request, thread_record: dict = None):
            # thread_record is injected if found
            ...

    Raises:
        HTTPException 401: If authentication required but user is anonymous
        HTTPException 403: If user lacks permission
        HTTPException 404: If owner_check=True but user doesn't own the thread
        ValueError: If owner_check=True but 'thread_id' parameter is missing
    """

    def decorator(func: Callable[P, T]) -> Callable[P, T]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            request = kwargs.get("request")
            if request is None:
                raise ValueError("require_permission decorator requires 'request' parameter")

            auth: AuthContext = getattr(request.state, "auth", None)
            if auth is None:
                auth = await _authenticate(request)
                request.state.auth = auth

            if not auth.is_authenticated:
                raise HTTPException(status_code=401, detail="Authentication required")

            # Check permission
            if not auth.has_permission(resource, action):
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied: {resource}:{action}",
                )

            # Owner check for thread-specific resources
            if owner_check:
                thread_id = kwargs.get("thread_id")
                if thread_id is None:
                    raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")

                # Get thread and verify ownership
                from app.gateway.routers.threads import _store_get, get_store

                store = get_store(request)
                if store is not None:
                    record = await _store_get(store, thread_id)
                    if record:
                        owner_id = record.get("metadata", {}).get(owner_filter_key)
                        if owner_id and owner_id != str(auth.user.id):
                            raise HTTPException(
                                status_code=404,
                                detail=f"Thread {thread_id} not found",
                            )
                        # Inject record if requested
                        if inject_record:
                            kwargs["thread_record"] = record

            return await func(*args, **kwargs)

        return wrapper

    return decorator
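Aside: a minimal sketch, not part of the commit, of AuthContext permission checks in isolation.

    from app.gateway.authz import AuthContext, Permissions

    ctx = AuthContext(user=None, permissions=[Permissions.THREADS_READ])
    assert ctx.has_permission("threads", "read")
    assert not ctx.has_permission("threads", "delete")
    assert not ctx.is_authenticated  # a permission list without a user is still anonymous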
112 backend/app/gateway/csrf_middleware.py Normal file
@@ -0,0 +1,112 @@
"""CSRF protection middleware for FastAPI.

Per RFC-001:
State-changing operations require CSRF protection.
"""

import secrets
from collections.abc import Callable

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp

CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64  # bytes


def is_secure_request(request: Request) -> bool:
    """Detect whether the original client request was made over HTTPS."""
    return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"


def generate_csrf_token() -> str:
    """Generate a secure random CSRF token."""
    return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)


def should_check_csrf(request: Request) -> bool:
    """Determine if a request needs CSRF validation.

    CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
    GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
    """
    if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
        return False

    path = request.url.path.rstrip("/")
    # Exempt /api/v1/auth/me endpoint
    if path == "/api/v1/auth/me":
        return False
    return True


_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
    {
        "/api/v1/auth/login/local",
        "/api/v1/auth/logout",
        "/api/v1/auth/register",
    }
)


def is_auth_endpoint(request: Request) -> bool:
    """Check if the request is to an auth endpoint.

    Auth endpoints don't need CSRF validation on the first call (no token exists yet).
    """
    return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS


class CSRFMiddleware(BaseHTTPMiddleware):
    """Middleware that implements CSRF protection using the Double Submit Cookie pattern."""

    def __init__(self, app: ASGIApp) -> None:
        super().__init__(app)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        _is_auth = is_auth_endpoint(request)

        if should_check_csrf(request) and not _is_auth:
            cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
            header_token = request.headers.get(CSRF_HEADER_NAME)

            if not cookie_token or not header_token:
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
                )

            if not secrets.compare_digest(cookie_token, header_token):
                return JSONResponse(
                    status_code=403,
                    content={"detail": "CSRF token mismatch."},
                )

        response = await call_next(request)

        # For auth endpoints that set up a session, also set the CSRF cookie
        if _is_auth and request.method == "POST":
            # Generate a new CSRF token for the session
            csrf_token = generate_csrf_token()
            is_https = is_secure_request(request)
            response.set_cookie(
                key=CSRF_COOKIE_NAME,
                value=csrf_token,
                httponly=False,  # Must be JS-readable for Double Submit Cookie pattern
                secure=is_https,
                samesite="strict",
            )

        return response


def get_csrf_token(request: Request) -> str | None:
    """Get the CSRF token from the current request's cookies.

    This is useful for server-side rendering where you need to embed
    the token in forms or headers.
    """
    return request.cookies.get(CSRF_COOKIE_NAME)
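Aside: an illustrative client-side sketch, not part of the commit; the base URL, credentials, and /api/v1/some-endpoint path are placeholders. After login sets the JS-readable csrf_token cookie, the client echoes it back in the X-CSRF-Token header.

    import httpx

    with httpx.Client(base_url="http://localhost:8000") as client:
        client.post(
            "/api/v1/auth/login/local",
            data={"username": "admin@example.com", "password": "a-long-password"},
        )
        # Double Submit: send the cookie value back as a header on mutating calls.
        csrf = client.cookies.get("csrf_token")
        client.post("/api/v1/some-endpoint", headers={"X-CSRF-Token": csrf}, json={})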
@@ -3,38 +3,22 @@
**Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` which returns ``None``.

Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
"""

from __future__ import annotations

from collections.abc import AsyncGenerator
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING

from fastapi import FastAPI, HTTPException, Request

from deerflow.runtime import RunManager, StreamBridge


@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
    """Bootstrap and tear down all LangGraph runtime singletons.

    Usage in ``app.py``::

        async with langgraph_runtime(app):
            yield
    """
    from deerflow.agents.checkpointer.async_provider import make_checkpointer
    from deerflow.runtime import make_store, make_stream_bridge

    async with AsyncExitStack() as stack:
        app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
        app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
        app.state.store = await stack.enter_async_context(make_store())
        app.state.run_manager = RunManager()
        yield

if TYPE_CHECKING:
    from app.gateway.auth.local_provider import LocalAuthProvider
    from app.gateway.auth.repositories.sqlite import SQLiteUserRepository

# ---------------------------------------------------------------------------
# Getters – called by routers per-request
@@ -68,3 +52,102 @@ def get_checkpointer(request: Request):
def get_store(request: Request):
    """Return the global store (may be ``None`` if not configured)."""
    return getattr(request.app.state, "store", None)


# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py)
# ---------------------------------------------------------------------------

# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None


def get_local_provider() -> LocalAuthProvider:
    """Get or create the cached LocalAuthProvider singleton."""
    global _cached_local_provider, _cached_repo
    if _cached_repo is None:
        from app.gateway.auth.repositories.sqlite import SQLiteUserRepository

        _cached_repo = SQLiteUserRepository()
    if _cached_local_provider is None:
        from app.gateway.auth.local_provider import LocalAuthProvider

        _cached_local_provider = LocalAuthProvider(repository=_cached_repo)
    return _cached_local_provider


async def get_current_user_from_request(request: Request):
    """Get the current authenticated user from the request cookie.

    Raises HTTPException 401 if not authenticated.
    """
    from app.gateway.auth import decode_token
    from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code

    access_token = request.cookies.get("access_token")
    if not access_token:
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
        )

    payload = decode_token(access_token)
    if isinstance(payload, TokenError):
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
        )

    provider = get_local_provider()
    user = await provider.get_user(payload.sub)
    if user is None:
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
        )

    # Token version mismatch → password was changed, token is stale
    if user.token_version != payload.ver:
        raise HTTPException(
            status_code=401,
            detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
        )

    return user


async def get_optional_user_from_request(request: Request):
    """Get optional authenticated user from request.

    Returns None if not authenticated.
    """
    try:
        return await get_current_user_from_request(request)
    except HTTPException:
        return None


# ---------------------------------------------------------------------------
# Runtime bootstrap
# ---------------------------------------------------------------------------


@asynccontextmanager
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
    """Bootstrap and tear down all LangGraph runtime singletons.

    Usage in ``app.py``::

        async with langgraph_runtime(app):
            yield
    """
    from deerflow.agents.checkpointer.async_provider import make_checkpointer
    from deerflow.runtime import make_store, make_stream_bridge

    async with AsyncExitStack() as stack:
        app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
        app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
        app.state.store = await stack.enter_async_context(make_store())
        app.state.run_manager = RunManager()
        yield
106 backend/app/gateway/langgraph_auth.py Normal file
@@ -0,0 +1,106 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.

Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.

Two layers:
1. @auth.authenticate — validates JWT cookie, extracts user_id,
   and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on — returns metadata filter so each user only sees own threads
"""

import secrets

from langgraph_sdk import Auth

from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider

auth = Auth()

# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})


def _check_csrf(request) -> None:
    """Enforce Double Submit Cookie CSRF check for state-changing requests.

    Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
    proxied directly by nginx have the same CSRF protection.
    """
    method = getattr(request, "method", "") or ""
    if method.upper() not in _CSRF_METHODS:
        return

    cookie_token = request.cookies.get("csrf_token")
    header_token = request.headers.get("x-csrf-token")

    if not cookie_token or not header_token:
        raise Auth.exceptions.HTTPException(
            status_code=403,
            detail="CSRF token missing. Include X-CSRF-Token header.",
        )

    if not secrets.compare_digest(cookie_token, header_token):
        raise Auth.exceptions.HTTPException(
            status_code=403,
            detail="CSRF token mismatch.",
        )


@auth.authenticate
async def authenticate(request):
    """Validate the session cookie, decode JWT, and check token_version.

    Same validation chain as Gateway's get_current_user_from_request:
    cookie → decode JWT → DB lookup → token_version match
    Also enforces CSRF on state-changing methods.
    """
    # CSRF check before authentication so forged cross-site requests
    # are rejected early, even if the cookie carries a valid JWT.
    _check_csrf(request)

    token = request.cookies.get("access_token")
    if not token:
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail="Not authenticated",
        )

    payload = decode_token(token)
    if isinstance(payload, TokenError):
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail=f"Token error: {payload.value}",
        )

    user = await get_local_provider().get_user(payload.sub)
    if user is None:
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail="User not found",
        )
    if user.token_version != payload.ver:
        raise Auth.exceptions.HTTPException(
            status_code=401,
            detail="Token revoked (password changed)",
        )

    return payload.sub


@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
    """Inject user_id metadata on writes; filter by user_id on reads.

    Gateway stores thread ownership as ``metadata.user_id``.
    This handler ensures LangGraph Server enforces the same isolation.
    """
    # On create/update: stamp user_id into metadata
    metadata = value.setdefault("metadata", {})
    metadata["user_id"] = ctx.user.identity

    # Return filter dict — LangGraph applies it to search/read/delete
    return {"user_id": ctx.user.identity}
@@ -1,3 +1,3 @@
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
from . import artifacts, assistants_compat, auth, mcp, models, skills, suggestions, thread_runs, threads, uploads

__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
__all__ = ["artifacts", "assistants_compat", "auth", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
@@ -24,7 +24,7 @@ class AgentResponse(BaseModel):
    description: str = Field(default="", description="Agent description")
    model: str | None = Field(default=None, description="Optional model override")
    tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
    soul: str | None = Field(default=None, description="SOUL.md content (included on GET /{name})")
    soul: str | None = Field(default=None, description="SOUL.md content")


class AgentsListResponse(BaseModel):
@@ -92,17 +92,17 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
    "/agents",
    response_model=AgentsListResponse,
    summary="List Custom Agents",
    description="List all custom agents available in the agents directory.",
    description="List all custom agents available in the agents directory, including their soul content.",
)
async def list_agents() -> AgentsListResponse:
    """List all custom agents.

    Returns:
        List of all custom agents with their metadata (without soul content).
        List of all custom agents with their metadata and soul content.
    """
    try:
        agents = list_custom_agents()
        return AgentsListResponse(agents=[_agent_config_to_response(a) for a in agents])
        return AgentsListResponse(agents=[_agent_config_to_response(a, include_soul=True) for a in agents])
    except Exception as e:
        logger.error(f"Failed to list agents: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to list agents: {str(e)}")
303 backend/app/gateway/routers/auth.py Normal file
@@ -0,0 +1,303 @@
"""Authentication endpoints."""

import logging
import time

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field

from app.gateway.auth import (
    UserResponse,
    create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/api/v1/auth", tags=["auth"])


# ── Request/Response Models ──────────────────────────────────────────────


class LoginResponse(BaseModel):
    """Response model for login — token only lives in HttpOnly cookie."""

    expires_in: int  # seconds
    needs_setup: bool = False


class RegisterRequest(BaseModel):
    """Request model for user registration."""

    email: EmailStr
    password: str = Field(..., min_length=8)


class ChangePasswordRequest(BaseModel):
    """Request model for password change (also handles setup flow)."""

    current_password: str
    new_password: str = Field(..., min_length=8)
    new_email: EmailStr | None = None


class MessageResponse(BaseModel):
    """Generic message response."""

    message: str


# ── Helpers ───────────────────────────────────────────────────────────────


def _set_session_cookie(response: Response, token: str, request: Request) -> None:
    """Set the access_token HttpOnly cookie on the response."""
    config = get_auth_config()
    is_https = is_secure_request(request)
    response.set_cookie(
        key="access_token",
        value=token,
        httponly=True,
        secure=is_https,
        samesite="lax",
        max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
    )


# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers. Sufficient for single-worker deployments.

_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300  # 5 minutes

# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}


def _get_client_ip(request: Request) -> str:
    """Extract the real client IP for rate limiting.

    Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
    $remote_addr``). Nginx unconditionally overwrites any client-supplied
    ``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
    that nginx observed — it cannot be spoofed by the client.

    ``request.client.host`` is NOT reliable because uvicorn's default
    ``proxy_headers=True`` replaces it with the *first* entry from
    ``X-Forwarded-For``, which IS client-spoofable.

    ``X-Forwarded-For`` is intentionally NOT used for the same reason.
    """
    real_ip = request.headers.get("x-real-ip", "").strip()
    if real_ip:
        return real_ip

    # Fallback: direct connection without nginx (e.g. unit tests, dev).
    return request.client.host if request.client else "unknown"


def _check_rate_limit(ip: str) -> None:
    """Raise 429 if the IP is currently locked out."""
    record = _login_attempts.get(ip)
    if record is None:
        return
    fail_count, lock_until = record
    if fail_count >= _MAX_LOGIN_ATTEMPTS:
        if time.time() < lock_until:
            raise HTTPException(
                status_code=429,
                detail="Too many login attempts. Try again later.",
            )
        del _login_attempts[ip]


_MAX_TRACKED_IPS = 10000


def _record_login_failure(ip: str) -> None:
    """Record a failed login attempt for the given IP."""
    # Evict expired lockouts when dict grows too large
    if len(_login_attempts) >= _MAX_TRACKED_IPS:
        now = time.time()
        expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
        for k in expired:
            del _login_attempts[k]
        # If still too large, evict cheapest-to-lose half: below-threshold
        # IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
        if len(_login_attempts) >= _MAX_TRACKED_IPS:
            by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
            for k, _ in by_time[: len(by_time) // 2]:
                del _login_attempts[k]

    record = _login_attempts.get(ip)
    if record is None:
        _login_attempts[ip] = (1, 0.0)
    else:
        new_count = record[0] + 1
        lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
        _login_attempts[ip] = (new_count, lock_until)


def _record_login_success(ip: str) -> None:
    """Clear failure counter for the given IP on successful login."""
    _login_attempts.pop(ip, None)
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/login/local", response_model=LoginResponse)
|
||||
async def login_local(
|
||||
request: Request,
|
||||
response: Response,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
"""Local email/password login."""
|
||||
client_ip = _get_client_ip(request)
|
||||
_check_rate_limit(client_ip)
|
||||
|
||||
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
|
||||
|
||||
if user is None:
|
||||
_record_login_failure(client_ip)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
|
||||
)
|
||||
|
||||
_record_login_success(client_ip)
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return LoginResponse(
|
||||
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
||||
needs_setup=user.needs_setup,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(request: Request, response: Response, body: RegisterRequest):
|
||||
"""Register a new user account (always 'user' role).
|
||||
|
||||
Admin is auto-created on first boot. This endpoint creates regular users.
|
||||
Auto-login by setting the session cookie.
|
||||
"""
|
||||
try:
|
||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
|
||||
)
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: Request, response: Response):
|
||||
"""Logout current user by clearing the cookie."""
|
||||
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
||||
return MessageResponse(message="Successfully logged out")
|
||||
|
||||
|
||||
@router.post("/change-password", response_model=MessageResponse)
|
||||
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
|
||||
"""Change password for the currently authenticated user.
|
||||
|
||||
Also handles the first-boot setup flow:
|
||||
- If new_email is provided, updates email (checks uniqueness)
|
||||
- If user.needs_setup is True and new_email is given, clears needs_setup
|
||||
- Always increments token_version to invalidate old sessions
|
||||
- Re-issues session cookie with new token_version
|
||||
"""
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
|
||||
user = await get_current_user_from_request(request)
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||
|
||||
if not await verify_password_async(body.current_password, user.password_hash):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
|
||||
|
||||
provider = get_local_provider()
|
||||
|
||||
# Update email if provided
|
||||
if body.new_email is not None:
|
||||
existing = await provider.get_user_by_email(body.new_email)
|
||||
if existing and str(existing.id) != str(user.id):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
|
||||
user.email = body.new_email
|
||||
|
||||
# Update password + bump version
|
||||
user.password_hash = await hash_password_async(body.new_password)
|
||||
user.token_version += 1
|
||||
|
||||
# Clear setup flag if this is the setup flow
|
||||
if user.needs_setup and body.new_email is not None:
|
||||
user.needs_setup = False
|
||||
|
||||
await provider.update_user(user)
|
||||
|
||||
# Re-issue cookie with new token_version
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
|
||||
return MessageResponse(message="Password changed successfully")
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(request: Request):
|
||||
"""Get current authenticated user info."""
|
||||
user = await get_current_user_from_request(request)
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||
|
||||
|
||||
@router.get("/setup-status")
|
||||
async def setup_status():
|
||||
"""Check if admin account exists. Always False after first boot."""
|
||||
user_count = await get_local_provider().count_users()
|
||||
return {"needs_setup": user_count == 0}
|
||||
|
||||
|
||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}")
|
||||
async def oauth_login(provider: str):
|
||||
"""Initiate OAuth login flow.
|
||||
|
||||
Redirects to the OAuth provider's authorization URL.
|
||||
Currently a placeholder - requires OAuth provider implementation.
|
||||
"""
|
||||
if provider not in ["github", "google"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported OAuth provider: {provider}",
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth login not yet implemented",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/callback/{provider}")
|
||||
async def oauth_callback(provider: str, code: str, state: str):
|
||||
"""OAuth callback endpoint.
|
||||
|
||||
Handles the OAuth provider's callback after user authorization.
|
||||
Currently a placeholder.
|
||||
"""
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
detail="OAuth callback not yet implemented",
|
||||
)
|
||||
@ -2,6 +2,7 @@ import json
import logging

from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field

from deerflow.models import create_chat_model
@ -106,22 +107,21 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
    if not conversation:
        return SuggestionsResponse(suggestions=[])

    prompt = (
    system_instruction = (
        "You are generating follow-up questions to help the user continue the conversation.\n"
        f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
        "Requirements:\n"
        "- Questions must be relevant to the conversation.\n"
        "- Questions must be relevant to the preceding conversation.\n"
        "- Questions must be written in the same language as the user.\n"
        "- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
        "- Do NOT include numbering, markdown, or any extra text.\n"
        "- Output MUST be a JSON array of strings only.\n\n"
        "Conversation:\n"
        f"{conversation}\n"
        "- Output MUST be a JSON array of strings only.\n"
    )
    user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"

    try:
        model = create_chat_model(name=request.model_name, thinking_enabled=False)
        response = model.invoke(prompt)
        response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
        raw = _extract_response_text(response.content)
        suggestions = _parse_json_string_list(raw) or []
        cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
@ -19,6 +19,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field

from app.gateway.authz import require_auth, require_permission
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
@ -92,19 +93,28 @@ def _record_to_response(record: RunRecord) -> RunResponse:


@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_auth
@require_permission("runs", "create", owner_check=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
    """Create a background run (returns immediately)."""
    """Create a background run (returns immediately).

    Multi-tenant isolation: only the thread owner can create runs.
    """
    record = await start_run(body, thread_id, request)
    return _record_to_response(record)


@router.post("/{thread_id}/runs/stream")
@require_auth
@require_permission("runs", "create", owner_check=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
    """Create a run and stream events via SSE.

    The response includes a ``Content-Location`` header with the run's
    resource URL, matching the LangGraph Platform protocol. The
    ``useStream`` React hook uses this to extract run metadata.

    Multi-tenant isolation: only the thread owner can stream runs.
    """
    bridge = get_stream_bridge(request)
    run_mgr = get_run_manager(request)
@ -125,8 +135,13 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -


@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_auth
@require_permission("runs", "create", owner_check=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
    """Create a run and block until it completes, returning the final state."""
    """Create a run and block until it completes, returning the final state.

    Multi-tenant isolation: only the thread owner can wait for runs.
    """
    record = await start_run(body, thread_id, request)

    if record.task is not None:
@ -150,16 +165,26 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->


@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_auth
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
    """List all runs for a thread."""
    """List all runs for a thread.

    Multi-tenant isolation: only the thread owner can list runs.
    """
    run_mgr = get_run_manager(request)
    records = await run_mgr.list_by_thread(thread_id)
    return [_record_to_response(r) for r in records]


@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_auth
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
    """Get details of a specific run."""
    """Get details of a specific run.

    Multi-tenant isolation: only the thread owner can get runs.
    """
    run_mgr = get_run_manager(request)
    record = run_mgr.get(run_id)
    if record is None or record.thread_id != thread_id:
@ -168,6 +193,8 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:


@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_auth
@require_permission("runs", "cancel", owner_check=True)
async def cancel_run(
    thread_id: str,
    run_id: str,
@ -181,6 +208,8 @@ async def cancel_run(
    - action=rollback: Stop execution, revert to pre-run checkpoint state
    - wait=true: Block until the run fully stops, return 204
    - wait=false: Return immediately with 202

    Multi-tenant isolation: only the thread owner can cancel runs.
    """
    run_mgr = get_run_manager(request)
    record = run_mgr.get(run_id)
@ -205,8 +234,13 @@ async def cancel_run(


@router.get("/{thread_id}/runs/{run_id}/join")
@require_auth
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
    """Join an existing run's SSE stream."""
    """Join an existing run's SSE stream.

    Multi-tenant isolation: only the thread owner can join runs.
    """
    bridge = get_stream_bridge(request)
    run_mgr = get_run_manager(request)
    record = run_mgr.get(run_id)
@ -13,17 +13,26 @@ matching the LangGraph Platform wire format expected by the
from __future__ import annotations

import logging
import re
import time
import uuid
from typing import Any
from typing import Annotated, Any

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from fastapi import APIRouter, HTTPException, Path, Request
from pydantic import BaseModel, Field, field_validator

from app.gateway.authz import require_auth, require_permission
from app.gateway.deps import get_checkpointer, get_store
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values

# ---------------------------------------------------------------------------
# Thread ID validation (prevents log-injection via control characters)
# ---------------------------------------------------------------------------

_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
ThreadId = Annotated[str, Path(description="Thread UUID", pattern=_UUID_RE.pattern)]

# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
@ -65,6 +74,13 @@ class ThreadCreateRequest(BaseModel):
    thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
    metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")

    @field_validator("thread_id")
    @classmethod
    def _validate_uuid(cls, v: str | None) -> str | None:
        if v is not None and not _UUID_RE.match(v):
            raise ValueError("thread_id must be a valid UUID")
        return v


class ThreadSearchRequest(BaseModel):
    """Request body for searching threads."""
@ -215,17 +231,23 @@ def _derive_thread_status(checkpoint_tuple) -> str:


@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
@require_auth
@require_permission("threads", "delete", owner_check=True)
async def delete_thread_data(thread_id: ThreadId, request: Request) -> ThreadDeleteResponse:
    """Delete local persisted filesystem data for a thread.

    Cleans DeerFlow-managed thread directories, removes checkpoint data,
    and removes the thread record from the Store.

    Multi-tenant isolation: only the thread owner can delete their thread.
    """
    store = get_store(request)
    checkpointer = get_checkpointer(request)

    # Clean local filesystem
    response = _delete_thread_data(thread_id)

    # Remove from Store (best-effort)
    store = get_store(request)
    if store is not None:
        try:
            await store.adelete(THREADS_NS, thread_id)
@ -233,7 +255,6 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
            logger.debug("Could not delete store record for thread %s (not critical)", thread_id)

    # Remove checkpoints (best-effort)
    checkpointer = getattr(request.app.state, "checkpointer", None)
    if checkpointer is not None:
        try:
            if hasattr(checkpointer, "adelete_thread"):
@ -251,12 +272,23 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
    The thread record is written to the Store (for fast listing) and an
    empty checkpoint is written to the checkpointer (for state reads).
    Idempotent: returns the existing record when ``thread_id`` already exists.

    If authenticated, the user's ID is injected into the thread metadata
    for multi-tenant isolation.
    """
    store = get_store(request)
    checkpointer = get_checkpointer(request)
    thread_id = body.thread_id or str(uuid.uuid4())
    now = time.time()

    from app.gateway.deps import get_optional_user_from_request

    user = await get_optional_user_from_request(request)

    thread_metadata = dict(body.metadata)
    if user:
        thread_metadata["user_id"] = str(user.id)

    # Idempotency: return existing record from Store when already present
    if store is not None:
        existing_record = await _store_get(store, thread_id)
@ -279,7 +311,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
                "status": "idle",
                "created_at": now,
                "updated_at": now,
                "metadata": body.metadata,
                "metadata": thread_metadata,
            },
        )
    except Exception:
@ -296,7 +328,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
            "source": "input",
            "writes": None,
            "parents": {},
            **body.metadata,
            **thread_metadata,
            "created_at": now,
        }
        await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
@ -304,13 +336,13 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
        logger.exception("Failed to create checkpoint for thread %s", thread_id)
        raise HTTPException(status_code=500, detail="Failed to create thread")

    logger.info("Thread created: %s", thread_id)
    logger.info("Thread created: %s (user_id=%s)", thread_id, thread_metadata.get("user_id"))
    return ThreadResponse(
        thread_id=thread_id,
        status="idle",
        created_at=str(now),
        updated_at=str(now),
        metadata=body.metadata,
        metadata=thread_metadata,
    )


@ -330,10 +362,18 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
    newly found thread is immediately written to the Store so that the next
    search skips Phase 2 for that thread — the Store converges to a full
    index over time without a one-shot migration job.

    If authenticated, only threads belonging to the current user are returned
    (enforced by user_id metadata filter for multi-tenant isolation).
    """
    store = get_store(request)
    checkpointer = get_checkpointer(request)

    from app.gateway.deps import get_optional_user_from_request

    user = await get_optional_user_from_request(request)
    user_id = str(user.id) if user else None

    # -----------------------------------------------------------------------
    # Phase 1: Store
    # -----------------------------------------------------------------------
@ -409,6 +449,10 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
    # -----------------------------------------------------------------------
    results = list(merged.values())

    # Multi-tenant isolation: filter by user_id if authenticated
    if user_id:
        results = [r for r in results if r.metadata.get("user_id") == user_id]

    if body.metadata:
        results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]

@ -420,13 +464,20 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th


@router.patch("/{thread_id}", response_model=ThreadResponse)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
    """Merge metadata into a thread record."""
@require_auth
@require_permission("threads", "write", owner_check=True, inject_record=True)
async def patch_thread(thread_id: ThreadId, request: Request, body: ThreadPatchRequest, thread_record: dict = None) -> ThreadResponse:
    """Merge metadata into a thread record.

    Multi-tenant isolation: only the thread owner can patch their thread.
    """
    store = get_store(request)
    if store is None:
        raise HTTPException(status_code=503, detail="Store not available")

    record = await _store_get(store, thread_id)
    record = thread_record
    if record is None:
        record = await _store_get(store, thread_id)
    if record is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

@ -451,12 +502,17 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques


@router.get("/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: ThreadId, request: Request) -> ThreadResponse:
    """Get thread info.

    Reads metadata from the Store and derives the accurate execution
    status from the checkpointer. Falls back to the checkpointer alone
    for threads that pre-date Store adoption (backward compat).

    Multi-tenant isolation: returns 404 if the thread does not belong to
    the authenticated user.
    """
    store = get_store(request)
    checkpointer = get_checkpointer(request)
@ -488,26 +544,33 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
            "metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
        }

    status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")  # type: ignore[union-attr]
    if record is None:
        raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")

    status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
    checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
    channel_values = checkpoint.get("channel_values", {})

    return ThreadResponse(
        thread_id=thread_id,
        status=status,
        created_at=str(record.get("created_at", "")),  # type: ignore[union-attr]
        updated_at=str(record.get("updated_at", "")),  # type: ignore[union-attr]
        metadata=record.get("metadata", {}),  # type: ignore[union-attr]
        created_at=str(record.get("created_at", "")),
        updated_at=str(record.get("updated_at", "")),
        metadata=record.get("metadata", {}),
        values=serialize_channel_values(channel_values),
    )


@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: ThreadId, request: Request) -> ThreadStateResponse:
    """Get the latest state snapshot for a thread.

    Channel values are serialized to ensure LangChain message objects
    are converted to JSON-safe dicts.

    Multi-tenant isolation: returns 404 if thread does not belong to user.
    """
    checkpointer = get_checkpointer(request)

@ -552,12 +615,16 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo


@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
@require_auth
@require_permission("threads", "write", owner_check=True)
async def update_thread_state(thread_id: ThreadId, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
    """Update thread state (e.g. for human-in-the-loop resume or title rename).

    Writes a new checkpoint that merges *body.values* into the latest
    channel values, then syncs any updated ``title`` field back to the Store
    so that ``/threads/search`` reflects the change immediately.

    Multi-tenant isolation: only the thread owner can update their thread.
    """
    checkpointer = get_checkpointer(request)
    store = get_store(request)
@ -635,8 +702,13 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re


@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
    """Get checkpoint history for a thread."""
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: ThreadId, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
    """Get checkpoint history for a thread.

    Multi-tenant isolation: returns 404 if thread does not belong to user.
    """
    checkpointer = get_checkpointer(request)

    config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
@ -116,6 +116,7 @@ def build_run_config(
    metadata: dict[str, Any] | None,
    *,
    assistant_id: str | None = None,
    user_id: str | None = None,
) -> dict[str, Any]:
    """Build a RunnableConfig dict for the agent.

@ -128,6 +129,9 @@ def build_run_config(
    This mirrors the channel manager's ``_resolve_run_params`` logic so that
    the LangGraph Platform-compatible HTTP API and the IM channel path behave
    identically.

    If *user_id* is provided, it is injected into the config metadata for
    multi-tenant isolation.
    """
    config: dict[str, Any] = {"recursion_limit": 100}
    if request_config:
@ -161,6 +165,11 @@ def build_run_config(
        if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
            raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
        config["configurable"]["agent_name"] = normalized

    # Multi-tenant isolation: inject user_id into metadata
    if user_id:
        config.setdefault("metadata", {})["user_id"] = user_id

    if metadata:
        config.setdefault("metadata", {}).update(metadata)
    return config
@ -260,6 +269,10 @@ async def start_run(

    disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_

    # Reuse auth context set by @require_auth decorator to avoid redundant DB lookup
    auth = getattr(request.state, "auth", None)
    user_id = str(auth.user.id) if auth and auth.user else None

    try:
        record = await run_mgr.create_or_reject(
            thread_id,
@ -282,7 +295,13 @@ async def start_run(

    agent_factory = resolve_agent_factory(body.assistant_id)
    graph_input = normalize_input(body.input)
    config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
    config = build_run_config(
        thread_id,
        body.config,
        body.metadata,
        assistant_id=body.assistant_id,
        user_id=user_id,
    )

    # Merge DeerFlow-specific context overrides into configurable.
    # The ``context`` field is a custom extension for the langgraph-compat layer
1786
backend/docs/AUTH_TEST_PLAN.md
Normal file
1786
backend/docs/AUTH_TEST_PLAN.md
Normal file
File diff suppressed because it is too large
129
backend/docs/AUTH_UPGRADE.md
Normal file
129
backend/docs/AUTH_UPGRADE.md
Normal file
@ -0,0 +1,129 @@
# Authentication Upgrade Guide

DeerFlow ships with a built-in authentication module. This document is for users upgrading from a version without authentication.

## Core Concepts

The authentication module follows an **always-enforced** policy:

- On first boot, an admin account is created automatically and a random password is printed to the console log
- Authentication is mandatory from the very start, so there is no race window
- Historical conversations (threads created before the upgrade) are migrated to the admin account automatically

## Upgrade Steps

### 1. Update the code

```bash
git pull origin main
cd backend && make install
```

### 2. First boot

```bash
make dev
```

The console prints:

```
============================================================
Admin account created on first boot
Email: admin@deerflow.dev
Password: aB3xK9mN_pQ7rT2w
Change it after login: Settings → Account
============================================================
```

If you restart the service before logging in, don't worry: as long as setup is incomplete, every boot resets the password and prints it to the console again.

### 3. Log in

Visit `http://localhost:2026/login` and sign in with the email and password from the console output.

### 4. Change the password

After logging in, go to Settings → Account → Change Password.

### 5. Add users (optional)

Other users register via the `/login` page and automatically receive the **user** role. Each user can only see their own conversations.

## Security Mechanisms

| Mechanism | Description |
|------|------|
| JWT HttpOnly cookie | The token is never exposed to JavaScript, preventing theft via XSS |
| CSRF double-submit cookie | All POST/PUT/DELETE requests must carry an `X-CSRF-Token` header |
| bcrypt password hashing | Passwords are never stored in plaintext |
| Multi-tenant isolation | Users can only access their own threads |
| HTTPS auto-detection | Detects `x-forwarded-proto` and sets the `Secure` cookie flag automatically |

## Common Operations

### Forgot password

```bash
cd backend

# Reset the admin password
python -m app.gateway.auth.reset_admin

# Reset a specific user's password
python -m app.gateway.auth.reset_admin --email user@example.com
```

A new random password is printed.

### Full reset

Delete the user database; a new admin is created automatically on restart:

```bash
rm -f backend/.deer-flow/users.db
# Restart the service; the new password is printed to the console
```

## Data Storage

| File | Contents |
|------|------|
| `.deer-flow/users.db` | SQLite user database (password hashes, roles) |
| `AUTH_JWT_SECRET` in `.env` | JWT signing key (an ephemeral key is auto-generated when unset; sessions are invalidated on restart) |

### Production Recommendation

```bash
# Generate a persistent JWT secret so users are not logged out on every restart
python -c "import secrets; print(secrets.token_urlsafe(32))"
# Add the output to .env:
# AUTH_JWT_SECRET=<generated secret>
```

## API Endpoints

| Endpoint | Method | Description |
|------|------|------|
| `/api/v1/auth/login/local` | POST | Email/password login (OAuth2 form) |
| `/api/v1/auth/register` | POST | Register a new user (user role) |
| `/api/v1/auth/logout` | POST | Log out (clears the cookie) |
| `/api/v1/auth/me` | GET | Get current user info |
| `/api/v1/auth/change-password` | POST | Change password |
| `/api/v1/auth/setup-status` | GET | Check whether an admin exists |
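For a non-browser client, the login round trip looks like the sketch below. The port and the `httpx` dependency are assumptions; the form field names come from the OAuth2 password form accepted by `/login/local`.

```python
# Minimal sketch, assuming the gateway runs on localhost:2026 and httpx is installed.
import httpx

with httpx.Client(base_url="http://localhost:2026") as client:
    # The OAuth2 password form carries the email in the "username" field.
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": "admin@deerflow.dev", "password": "<console password>"},
    )
    resp.raise_for_status()
    print(resp.json())  # {"expires_in": ..., "needs_setup": ...}

    # The JWT now lives in the HttpOnly access_token cookie held by this client,
    # so follow-up requests are authenticated automatically.
    print(client.get("/api/v1/auth/me").json())
    # Note: mutating requests (POST/PUT/DELETE) may additionally need to echo
    # the CSRF cookie in an X-CSRF-Token header, per the CSRF middleware.
```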
## Compatibility

- **Standard mode** (`make dev`): fully compatible; the admin is auto-created
- **Gateway mode** (`make dev-pro`): fully compatible
- **Docker deployments**: fully compatible; `.deer-flow/users.db` must be on a persistent volume mount
- **IM channels** (Feishu/Slack/Telegram): communicate through the LangGraph SDK and do not pass through the auth layer
- **DeerFlowClient** (embedded): does not go over HTTP and is unaffected by authentication

## Troubleshooting

| Symptom | Cause | Fix |
|------|------|------|
| No password printed on startup | Admin already exists (not a first boot) | Reset with `reset_admin`, or delete `users.db` |
| POST returns 403 after login | Missing CSRF token | Make sure the frontend is up to date |
| Re-login required after restart | `AUTH_JWT_SECRET` not persisted | Set a fixed secret in `.env` |
@ -248,7 +248,7 @@ def after_agent(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | N
- [`packages/harness/deerflow/agents/thread_state.py`](../packages/harness/deerflow/agents/thread_state.py) - ThreadState definition
- [`packages/harness/deerflow/agents/middlewares/title_middleware.py`](../packages/harness/deerflow/agents/middlewares/title_middleware.py) - TitleMiddleware implementation
- [`packages/harness/deerflow/config/title_config.py`](../packages/harness/deerflow/config/title_config.py) - configuration management
- [`config.yaml`](../config.yaml) - configuration file
- [`config.yaml`](../../config.example.yaml) - configuration file
- [`packages/harness/deerflow/agents/lead_agent/agent.py`](../packages/harness/deerflow/agents/lead_agent/agent.py) - middleware registration

## References

@ -30,7 +30,7 @@

### 2. Configuration files

#### [`config.yaml`](../config.yaml)
#### [`config.yaml`](../../config.example.yaml)
- ✅ Added the title configuration section:
```yaml
title:
@ -51,7 +51,7 @@ title:
- ✅ Troubleshooting guide
- ✅ State vs Metadata comparison

#### [`BACKEND_TODO.md`](../BACKEND_TODO.md)
#### [`TODO.md`](TODO.md)
- ✅ Added a feature-completion record

### 4. Tests

446
backend/docs/rfc-grep-glob-tools.md
Normal file
446
backend/docs/rfc-grep-glob-tools.md
Normal file
@ -0,0 +1,446 @@
# [RFC] Add `grep` and `glob` file-search tools to DeerFlow

## Summary

I believe this direction is right, and worth doing.

If DeerFlow wants to get closer to the actual workflow of coding agents like Claude Code, `ls` / `read_file` / `write_file` / `str_replace` alone are not enough. Before it starts editing, the model usually needs two more capabilities:

- `glob`: quickly find files by path pattern
- `grep`: quickly find candidate locations by content pattern

The value of these tools is not "bash could functionally do this too"; it is that they replace the model's habit of reaching for `bash find` / `bash grep` / `rg` with lower token cost, stronger constraints, and a more stable output format.

But only if the implementation is right: **they should be read-only, structured, constrained, auditable native tools, not thin wrappers around shell commands.**

## Problem

DeerFlow's file-tool layer currently covers:

- `ls`: browse directory structure
- `read_file`: read file contents
- `write_file`: write files
- `str_replace`: local string replacement
- `bash`: catch-all command execution

This set can get tasks done, but it is inefficient during the codebase-exploration phase.

Typical problems:

1. To find "all `*.tsx` page files", the model can only run `ls` through directory levels repeatedly, or fall back to `bash find`
2. To find "where a symbol / string / config key appears", the model can only `read_file` file by file, or fall back to `bash grep` / `rg`
3. Once it falls back to `bash`, tool calls lose structured output, and results become harder to trim, paginate, audit, and keep consistent across sandboxes
4. In local modes without host bash enabled, `bash` may not even be available, leaving no sufficiently strong read-only search capability

Conclusion: what DeerFlow is missing is not "one more shell command" but a **filesystem search layer**.

## Goals

- Give the agent stable path-search and content-search capabilities
- Reduce the reliance on `bash`, especially during repository exploration
- Stay consistent with the existing sandbox security model
- Produce structured output that the model can chain into `read_file` / `str_replace`
- Let the local sandbox, container sandboxes, and future MCP filesystem tools all follow the same semantics

## Non-Goals

- No general-purpose shell compatibility layer
- No exposure of the full grep/find/rg CLI syntax
- No heavyweight v1 features such as binary search, advanced PCRE features, or context-window highlighting
- Not an "arbitrary disk search": execution stays restricted to paths DeerFlow has already authorized

## Why This Is Worth Doing

Following the design thinking of agents like Claude Code, the core value of `glob` and `grep` is not the new capability itself, but demoting the common "explore the codebase" actions from an open-ended shell to a controlled tool layer.

This brings several direct benefits:

1. **Lower burden on the model**
   The model no longer has to assemble `find`, `grep`, `rg`, `xargs`, quoting, and other command details itself.

2. **More stable cross-environment behavior**
   Local, Docker, and AIO sandboxes no longer depend on whether `rg` is installed in the container, and behavior does not drift with shell differences.

3. **Stronger security and auditability**
   The call parameters boil down to "what to search, where, and how many results at most", which is inherently easier to audit and rate-limit than arbitrary commands.

4. **Better token efficiency**
   `grep` returns hit summaries rather than whole file sections, so the model calls `read_file` on only a few candidate paths.

5. **Friendly to `tool_search`**
   As DeerFlow keeps expanding its toolset, `grep` / `glob` will become very high-frequency foundational tools; they are worth keeping as built-ins rather than letting the model keep falling back to generic bash.

## Proposal

Add two built-in sandbox tools:

- `glob`
- `grep`

Recommended location, as before:

- `backend/packages/harness/deerflow/sandbox/tools.py`

and add them to the `file:read` group by default in `config.example.yaml`.

### 1. The `glob` tool

Purpose: find files or directories by path pattern.

Suggested schema:

```python
@tool("glob", parse_docstring=True)
def glob_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    description: str,
    pattern: str,
    path: str,
    include_dirs: bool = False,
    max_results: int = 200,
) -> str:
    ...
```

Parameter semantics:

- `description`: consistent with the existing tools
- `pattern`: a glob pattern, e.g. `**/*.py` or `src/**/test_*.ts`
- `path`: the search root; must be an absolute path
- `include_dirs`: whether to return directories
- `max_results`: maximum number of results, so one call cannot blow up the context

Suggested return format:

```text
Found 3 paths under /mnt/user-data/workspace
1. /mnt/user-data/workspace/backend/app.py
2. /mnt/user-data/workspace/backend/tests/test_app.py
3. /mnt/user-data/workspace/scripts/build.py
```

If a format friendlier to frontend consumption is wanted later, this could become a JSON string; for v1, readable text keeps it consistent with the existing tool style.
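To make the intended traversal semantics concrete, here is a minimal sketch of the core loop. This is not existing code: `_IGNORED_DIRS` stands in for the shared ignore set proposed later in this RFC, and the root is assumed to have already passed path validation.

```python
# Minimal sketch, assuming `root` already passed read-only path validation.
# _IGNORED_DIRS is a stand-in for the shared ignore set this RFC proposes.
from pathlib import Path

_IGNORED_DIRS = {".git", "node_modules", "__pycache__", ".venv"}


def _glob_paths(root: str, pattern: str, include_dirs: bool = False, max_results: int = 200) -> list[str]:
    base = Path(root)
    if not base.is_dir():
        raise ValueError(f"Not a directory: {root}")
    results: list[str] = []
    for p in sorted(base.glob(pattern)):
        # Skip anything under an ignored directory so glob stays consistent with ls.
        if any(part in _IGNORED_DIRS for part in p.parts):
            continue
        if p.is_dir() and not include_dirs:
            continue
        results.append(str(p))
        if len(results) >= max_results:
            break  # hard cap so a single call cannot blow up the context
    return results
```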
### 2. The `grep` tool

Purpose: search file contents by pattern and return a summary of hit locations.

Suggested schema:

```python
@tool("grep", parse_docstring=True)
def grep_tool(
    runtime: ToolRuntime[ContextT, ThreadState],
    description: str,
    pattern: str,
    path: str,
    glob: str | None = None,
    literal: bool = False,
    case_sensitive: bool = False,
    max_results: int = 100,
) -> str:
    ...
```

Parameter semantics:

- `pattern`: a search term or regex
- `path`: the search root; must be an absolute path
- `glob`: optional path filter, e.g. `**/*.py`
- `literal`: when `True`, match as a plain string rather than interpreting the pattern as a regex
- `case_sensitive`: whether matching is case-sensitive
- `max_results`: the maximum number of hits returned (hits, not files)

Suggested return format:

```text
Found 4 matches under /mnt/user-data/workspace
/mnt/user-data/workspace/backend/config.py:12: TOOL_GROUPS = [...]
/mnt/user-data/workspace/backend/config.py:48: def load_tool_config(...):
/mnt/user-data/workspace/backend/tools.py:91: "tool_groups"
/mnt/user-data/workspace/backend/tests/test_config.py:22: assert "tool_groups" in data
```

For v1, return only:

- the file path
- the line number
- a summary of the matching line

No context blocks, to keep results small. If the model needs context, it follows up with `read_file(path, start_line, end_line)`.
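Similarly, a minimal sketch of the scan loop under the limits this RFC proposes. The NUL-byte binary check and the 1 MB size cap are the suggested heuristics, not existing DeerFlow code.

```python
# Minimal sketch; the binary check and size cap are the heuristics suggested
# in this RFC, not existing DeerFlow code.
import re
from pathlib import Path

_MAX_FILE_BYTES = 1_000_000
_MAX_LINE_CHARS = 200


def _grep_paths(root: str, pattern: str, *, literal: bool = False,
                case_sensitive: bool = False, max_results: int = 100) -> list[str]:
    flags = 0 if case_sensitive else re.IGNORECASE
    rx = re.compile(re.escape(pattern) if literal else pattern, flags)
    hits: list[str] = []
    for path in sorted(Path(root).rglob("*")):
        if not path.is_file() or path.stat().st_size > _MAX_FILE_BYTES:
            continue
        raw = path.read_bytes()
        if b"\x00" in raw:  # crude binary detection: skip files containing NUL bytes
            continue
        for lineno, line in enumerate(raw.decode("utf-8", errors="replace").splitlines(), 1):
            if rx.search(line):
                hits.append(f"{path}:{lineno}: {line.strip()[:_MAX_LINE_CHARS]}")
                if len(hits) >= max_results:
                    return hits  # truncated; the caller reports the truncation
    return hits
```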
## Design Principles

### A. No shell wrapper

Do not implement `grep` as:

```python
subprocess.run("grep ...")
```

and do not assemble `find` / `rg` commands directly inside the container either.

Reasons:

- It introduces a shell-quoting and injection surface
- It depends on the same commands being installed across sandbox images
- Windows / macOS / Linux behavior diverges
- Output count and format are hard to control reliably

The right direction is:

- `glob` uses Python standard-library path traversal
- `grep` scans files one by one in Python
- DeerFlow formats the output itself

If `rg` is later preferred for performance, it should be encapsulated inside the provider with the external semantics unchanged, rather than exposing the CLI to the model.

### B. Keep DeerFlow's path-permission model

These two tools must reuse the path-validation logic of the current `ls` / `read_file`:

- Local mode goes through `validate_local_tool_path(..., read_only=True)`
- Support `/mnt/skills/...`
- Support `/mnt/acp-workspace/...`
- Support virtual-path resolution for the thread workspace / uploads / outputs
- Explicitly reject out-of-scope paths and path traversal

In other words, they belong to **file:read**; they are not a privilege-escalation side door around `bash`.

### C. Results must be hard-limited

A `glob` / `grep` without hard limits will easily blow up the context.

At minimum, v1 should limit:

- `glob.max_results`: default 200, max 1000
- `grep.max_results`: default 100, max 500
- maximum per-line summary length, e.g. 200 characters
- binary files are skipped
- oversized files are skipped, e.g. over 1 MB per file, or configurable

In addition, when the hit count exceeds the threshold, the response should report:

- how many entries were shown
- the fact that results were truncated
- a suggestion to narrow the search scope

For example:

```text
Found more than 100 matches, showing first 100. Narrow the path or add a glob filter.
```

### D. Tool semantics should complement each other

The recommended model workflow, illustrated by the trace after this list, is:

1. `glob` to find candidate files
2. `grep` to find candidate locations
3. `read_file` to read local context
4. `str_replace` / `write_file` to make the edit

This keeps tool boundaries clear and makes it easier for prompts to teach the model stable habits.
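As a concrete illustration of that chain (tool names as proposed here; the paths and patterns are invented for the example):

```text
glob(pattern="**/*config*.py", path="/mnt/user-data/workspace")
grep(pattern="tool_groups", path="/mnt/user-data/workspace", glob="**/*.py")
read_file(path="/mnt/user-data/workspace/backend/config.py", start_line=40, end_line=60)
str_replace(path="/mnt/user-data/workspace/backend/config.py", ...)
```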
## Implementation Approach

### Option A: Implement v1 directly in `sandbox/tools.py`

This is my recommended starting point.

Approach:

- Add `glob_tool` and `grep_tool` to `sandbox/tools.py`
- In the local sandbox, use Python filesystem APIs directly
- In non-local sandboxes, also prefer going through DeerFlow's own path-access layer

Pros:

- Small change
- Validates the agent-side impact quickly
- No need to change the `Sandbox` abstraction first

Cons:

- `tools.py` keeps getting fatter
- If provider-side performance optimization is wanted later, another round of abstraction is needed

### Option B: Extend the `Sandbox` abstraction first

For example, add:

```python
class Sandbox(ABC):
    def glob(self, path: str, pattern: str, include_dirs: bool = False, max_results: int = 200) -> list[str]:
        ...

    def grep(
        self,
        path: str,
        pattern: str,
        *,
        glob: str | None = None,
        literal: bool = False,
        case_sensitive: bool = False,
        max_results: int = 100,
    ) -> list[GrepMatch]:
        ...
```

Pros:

- Cleaner abstraction
- Container / remote sandboxes can each optimize independently

Cons:

- Higher up-front cost
- All sandbox providers must be changed in step

Conclusion:

**Go with Option A for v1, and push down into the `Sandbox` abstraction once the tools' value is proven.**

## Detailed Behavior

### `glob` behavior

- Root directory does not exist: return a clear error
- Root path is not a directory: return a clear error
- Invalid pattern: return a clear error
- Empty result: return `No files matched`
- Default ignores should align with the current `list_dir` as closely as possible, e.g.:
  - `.git`
  - `node_modules`
  - `__pycache__`
  - `.venv`
  - build-output directories

A shared ignore set should be extracted here, so that `ls` and `glob` results do not diverge in style.

### `grep` behavior

- Scan only text files by default
- Skip files detected as binary
- Skip oversized files, or scan only the first N KB
- Return a parameter error when the regex fails to compile
- Paths in the output keep using virtual paths rather than exposing real host paths
- Sort by file path, then line number, by default, to keep output stable

## Prompting Guidance

If these two tools are introduced, the file-operation guidance in the system prompt should be updated in step:

- Prefer `glob` when looking for filename patterns
- Prefer `grep` when looking for code symbols, config keys, or copy
- Fall back to `bash` only when the tools cannot accomplish the goal

Otherwise the model will still habitually call `bash` first.

## Risks

### 1. Overlap with `bash`

True, but not a problem.

`ls` and `read_file` could also be replaced by `bash`, yet we keep them, because structured tools suit agents better.

### 2. Performance

On large repositories, a pure-Python `grep` may be slower than `rg`.

Mitigations:

- Start v1 with result caps and file-size caps
- Require a root path
- Provide a `glob` filter to narrow the scan scope
- If needed later, optimize with `rg` inside the provider while keeping the same schema

### 3. Inconsistent ignore rules

If `glob` cannot see a path that `ls` can, the model gets confused.

Mitigations:

- Unify the ignore rules
- State clearly in the docs that common dependency and build directories are skipped by default

### 4. Regex search becoming too complex

Supporting many grep dialects in v1 would make the boundary messy.

Mitigations:

- v1 supports Python `re` only
- and offers a simple mode via `literal=True`

## Alternatives Considered

### A. Add no tools and rely entirely on `bash`

Not recommended.

This leaves DeerFlow persistently behind on code-exploration experience, and weakens it in no-bash or restricted-bash scenarios.

### B. Add only `glob`, not `grep`

Not recommended.

It only solves "find the file", not "find the location". The model would still fall back to `bash grep`.

### C. Add only `grep`, not `glob`

Also not recommended.

Without path-pattern filtering, `grep`'s scan scope is often too large; `glob` is its natural front end.

### D. Plug in the search capabilities of an MCP filesystem server

Not recommended as the primary path in the short term.

MCP can be a complement, but as foundational coding tools for DeerFlow, `glob` / `grep` are best kept built-in, so that they are reliably available in a default install.

## Acceptance Criteria

- `glob` and `grep` can be enabled by default in `config.example.yaml`
- Both tools belong to the `file:read` group
- The local sandbox strictly follows the existing path permissions
- Output does not leak real host paths
- Large result sets are truncated with a clear notice
- The model can complete a typical code-editing flow via `glob -> grep -> read_file -> str_replace`
- In local mode with host bash disabled, repository-exploration capability improves noticeably

## Rollout Plan

1. Implement `glob_tool` and `grep_tool` in `sandbox/tools.py`
2. Extract ignore rules consistent with `list_dir` to avoid behavior drift
3. Add the tool config to `config.example.yaml` by default
4. Add tests for local path validation, virtual-path mapping, result truncation, and binary skipping
5. Update the README / backend docs / prompt guidance
6. Collect real agent call data, then decide whether to push down into the `Sandbox` abstraction

## Suggested Config

```yaml
tools:
  - name: glob
    group: file:read
    use: deerflow.sandbox.tools:glob_tool

  - name: grep
    group: file:read
    use: deerflow.sandbox.tools:grep_tool
```

## Final Recommendation

The conclusion: **yes, this can and should be added.**

But I would pin down three boundaries explicitly:

1. `grep` / `glob` must be built-in, read-only, structured tools
2. v1 must not be a shell wrapper, and must not expose CLI dialects to the model
3. Prove the value in `sandbox/tools.py` first, then consider pushing down into the `Sandbox` provider abstraction

Done this way, it will clearly improve DeerFlow's usability in coding / repo-exploration scenarios, and the risk stays controllable.
@ -8,6 +8,9 @@
  "graphs": {
    "lead_agent": "deerflow.agents:make_lead_agent"
  },
  "auth": {
    "path": "./app/gateway/langgraph_auth.py:auth"
  },
  "checkpointer": {
    "path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
  }
@ -8,6 +8,14 @@ from deerflow.subagents import get_available_subagent_names
logger = logging.getLogger(__name__)


def _get_enabled_skills():
    try:
        return list(load_skills(enabled_only=True))
    except Exception:
        logger.exception("Failed to load enabled skills for prompt injection")
        return []


def _build_subagent_section(max_concurrent: int) -> str:
    """Build the subagent system prompt section with dynamic concurrency limit.

@ -386,7 +394,7 @@ def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
    Returns the <skill_system>...</skill_system> block listing all enabled skills,
    suitable for injection into any agent's system prompt.
    """
    skills = load_skills(enabled_only=True)
    skills = _get_enabled_skills()

    try:
        from deerflow.config import get_app_config
@ -450,7 +458,7 @@ def get_deferred_tools_prompt_section() -> str:

        if not get_app_config().tool_search.enabled:
            return ""
    except FileNotFoundError:
    except Exception:
        return ""

    registry = get_deferred_registry()
@ -246,6 +246,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
    if earlier.get("summary"):
        history_sections.append(f"Earlier: {earlier['summary']}")

    background = history_data.get("longTermBackground", {})
    if background.get("summary"):
        history_sections.append(f"Background: {background['summary']}")

    if history_sections:
        sections.append("History:\n" + "\n".join(f"- {s}" for s in history_sections))
@ -21,6 +21,7 @@ class ConversationContext:
    timestamp: datetime = field(default_factory=datetime.utcnow)
    agent_name: str | None = None
    correction_detected: bool = False
    reinforcement_detected: bool = False


class MemoryUpdateQueue:
@ -44,6 +45,7 @@ class MemoryUpdateQueue:
        messages: list[Any],
        agent_name: str | None = None,
        correction_detected: bool = False,
        reinforcement_detected: bool = False,
    ) -> None:
        """Add a conversation to the update queue.

@ -52,6 +54,7 @@ class MemoryUpdateQueue:
            messages: The conversation messages.
            agent_name: If provided, memory is stored per-agent. If None, uses global memory.
            correction_detected: Whether recent turns include an explicit correction signal.
            reinforcement_detected: Whether recent turns include a positive reinforcement signal.
        """
        config = get_memory_config()
        if not config.enabled:
@ -63,11 +66,13 @@ class MemoryUpdateQueue:
            None,
        )
        merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
        merged_reinforcement_detected = reinforcement_detected or (existing_context.reinforcement_detected if existing_context is not None else False)
        context = ConversationContext(
            thread_id=thread_id,
            messages=messages,
            agent_name=agent_name,
            correction_detected=merged_correction_detected,
            reinforcement_detected=merged_reinforcement_detected,
        )

        # Check if this thread already has a pending update
@ -130,6 +135,7 @@ class MemoryUpdateQueue:
            thread_id=context.thread_id,
            agent_name=context.agent_name,
            correction_detected=context.correction_detected,
            reinforcement_detected=context.reinforcement_detected,
        )
        if success:
            logger.info("Memory updated successfully for thread %s", context.thread_id)
|
||||
stripped = content.strip()
|
||||
if not stripped:
|
||||
return None
|
||||
return stripped
|
||||
return stripped.casefold()
|
||||
|
||||
|
||||
class MemoryUpdater:
|
||||
@ -272,6 +272,7 @@ class MemoryUpdater:
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
@ -280,6 +281,7 @@ class MemoryUpdater:
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@ -310,6 +312,14 @@ class MemoryUpdater:
|
||||
"and record the correct approach as a fact with category "
|
||||
'"correction" and confidence >= 0.95 when appropriate.'
|
||||
)
|
||||
if reinforcement_detected:
|
||||
reinforcement_hint = (
|
||||
"IMPORTANT: Positive reinforcement signals were detected in this conversation. "
|
||||
"The user explicitly confirmed the agent's approach was correct or helpful. "
|
||||
"Record the confirmed approach, style, or preference as a fact with category "
|
||||
'"preference" or "behavior" and confidence >= 0.9 when appropriate.'
|
||||
)
|
||||
correction_hint = (correction_hint + "\n" + reinforcement_hint).strip() if correction_hint else reinforcement_hint
|
||||
|
||||
prompt = MEMORY_UPDATE_PROMPT.format(
|
||||
current_memory=json.dumps(current_memory, indent=2),
|
||||
@ -441,6 +451,7 @@ def update_memory_from_conversation(
|
||||
thread_id: str | None = None,
|
||||
agent_name: str | None = None,
|
||||
correction_detected: bool = False,
|
||||
reinforcement_detected: bool = False,
|
||||
) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
@ -449,9 +460,10 @@ def update_memory_from_conversation(
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
correction_detected: Whether recent turns include an explicit correction signal.
|
||||
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected)
|
||||
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
|
||||
|
||||
@ -182,6 +182,23 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):

        return None, False

    @staticmethod
    def _append_text(content: str | list | None, text: str) -> str | list:
        """Append *text* to AIMessage content, handling str, list, and None.

        When content is a list of content blocks (e.g. Anthropic thinking mode),
        we append a new ``{"type": "text", ...}`` block instead of concatenating
        a string to a list, which would raise ``TypeError``.
        """
        if content is None:
            return text
        if isinstance(content, list):
            return [*content, {"type": "text", "text": f"\n\n{text}"}]
        if isinstance(content, str):
            return content + f"\n\n{text}"
        # Fallback: coerce unexpected types to str to avoid TypeError
        return str(content) + f"\n\n{text}"

    def _apply(self, state: AgentState, runtime: Runtime) -> dict | None:
        warning, hard_stop = self._track_and_check(state, runtime)

@ -192,7 +209,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
            stripped_msg = last_msg.model_copy(
                update={
                    "tool_calls": [],
                    "content": (last_msg.content or "") + f"\n\n{_HARD_STOP_MSG}",
                    "content": self._append_text(last_msg.content, _HARD_STOP_MSG),
                }
            )
            return {"messages": [stripped_msg]}
@@ -29,6 +29,22 @@ _CORRECTION_PATTERNS = (
    re.compile(r"改用"),
)

_REINFORCEMENT_PATTERNS = (
    re.compile(r"\byes[,.]?\s+(?:exactly|perfect|that(?:'s| is) (?:right|correct|it))\b", re.IGNORECASE),
    re.compile(r"\bperfect(?:[.!?]|$)", re.IGNORECASE),
    re.compile(r"\bexactly\s+(?:right|correct)\b", re.IGNORECASE),
    re.compile(r"\bthat(?:'s| is)\s+(?:exactly\s+)?(?:right|correct|what i (?:wanted|needed|meant))\b", re.IGNORECASE),
    re.compile(r"\bkeep\s+(?:doing\s+)?that\b", re.IGNORECASE),
    re.compile(r"\bjust\s+(?:like\s+)?(?:that|this)\b", re.IGNORECASE),
    re.compile(r"\bthis is (?:great|helpful)\b(?:[.!?]|$)", re.IGNORECASE),
    re.compile(r"\bthis is what i wanted\b(?:[.!?]|$)", re.IGNORECASE),
    re.compile(r"对[,,]?\s*就是这样(?:[。!?!?.]|$)"),
    re.compile(r"完全正确(?:[。!?!?.]|$)"),
    re.compile(r"(?:对[,,]?\s*)?就是这个意思(?:[。!?!?.]|$)"),
    re.compile(r"正是我想要的(?:[。!?!?.]|$)"),
    re.compile(r"继续保持(?:[。!?!?.]|$)"),
)


class MemoryMiddlewareState(AgentState):
    """Compatible with the `ThreadState` schema."""

@@ -132,6 +148,29 @@ def detect_correction(messages: list[Any]) -> bool:
    return False


def detect_reinforcement(messages: list[Any]) -> bool:
    """Detect explicit positive reinforcement signals in recent conversation turns.

    Complements detect_correction() by identifying when the user confirms the
    agent's approach was correct. This allows the memory system to record what
    worked well, not just what went wrong.

    The queue keeps only one pending context per thread, so callers pass the
    latest filtered message list. Checking only recent user turns keeps signal
    detection conservative while avoiding stale signals from long histories.
    """
    recent_user_msgs = [msg for msg in messages[-6:] if getattr(msg, "type", None) == "human"]

    for msg in recent_user_msgs:
        content = _extract_message_text(msg).strip()
        if not content:
            continue
        if any(pattern.search(content) for pattern in _REINFORCEMENT_PATTERNS):
            return True

    return False


class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
    """Middleware that queues conversation for memory update after agent execution.

@@ -196,12 +235,14 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):

        # Queue the filtered conversation for memory update
        correction_detected = detect_correction(filtered_messages)
        reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
        queue = get_memory_queue()
        queue.add(
            thread_id=thread_id,
            messages=filtered_messages,
            agent_name=self._agent_name,
            correction_detected=correction_detected,
            reinforcement_detected=reinforcement_detected,
        )

        return None
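A minimal sketch of how the detector behaves (illustrative; `FakeMsg` is a hypothetical stand-in for a LangChain human message, assuming `_extract_message_text` reads its `content`):

from dataclasses import dataclass

@dataclass
class FakeMsg:
    type: str
    content: str

history = [
    FakeMsg("human", "Summarize the report"),
    FakeMsg("ai", "Here is the summary..."),
    FakeMsg("human", "Perfect, this is what I wanted."),
]
# The last human turn matches _REINFORCEMENT_PATTERNS
# (e.g. r"\bthis is what i wanted\b"), so this returns True:
detect_reinforcement(history)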
@@ -101,44 +101,33 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
        return user_msg if user_msg else "New Conversation"

    def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
-       """Synchronously generate a title. Returns state update or None."""
+       """Generate a local fallback title without blocking on an LLM call."""
        if not self._should_generate_title(state):
            return None

-       prompt, user_msg = self._build_title_prompt(state)
-       config = get_title_config()
-       model = create_chat_model(name=config.model_name, thinking_enabled=False)
-
-       try:
-           response = model.invoke(prompt)
-           title = self._parse_title(response.content)
-           if not title:
-               title = self._fallback_title(user_msg)
-       except Exception:
-           logger.exception("Failed to generate title (sync)")
-           title = self._fallback_title(user_msg)
-
-       return {"title": title}
+       _, user_msg = self._build_title_prompt(state)
+       return {"title": self._fallback_title(user_msg)}

    async def _agenerate_title_result(self, state: TitleMiddlewareState) -> dict | None:
-       """Asynchronously generate a title. Returns state update or None."""
+       """Generate a title asynchronously and fall back locally on failure."""
        if not self._should_generate_title(state):
            return None

-       prompt, user_msg = self._build_title_prompt(state)
        config = get_title_config()
-       model = create_chat_model(name=config.model_name, thinking_enabled=False)
+       prompt, user_msg = self._build_title_prompt(state)

        try:
+           if config.model_name:
+               model = create_chat_model(name=config.model_name, thinking_enabled=False)
+           else:
+               model = create_chat_model(thinking_enabled=False)
            response = await model.ainvoke(prompt)
            title = self._parse_title(response.content)
-           if not title:
-               title = self._fallback_title(user_msg)
+           if title:
+               return {"title": title}
        except Exception:
-           logger.exception("Failed to generate title (async)")
-           title = self._fallback_title(user_msg)
-
-       return {"title": title}
+           logger.debug("Failed to generate async title; falling back to local title", exc_info=True)
+       return {"title": self._fallback_title(user_msg)}

    @override
    def after_model(self, state: TitleMiddlewareState, runtime: Runtime) -> dict | None:
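Condensed, the async path above is a standard "LLM first, deterministic fallback" shape; a sketch with hypothetical helper names:

async def title_with_fallback(state) -> dict:
    prompt, user_msg = build_title_prompt(state)  # hypothetical, mirrors _build_title_prompt
    try:
        response = await model.ainvoke(prompt)    # model built as in the hunk above
        title = parse_title(response.content)     # hypothetical, mirrors _parse_title
        if title:
            return {"title": title}
    except Exception:
        pass  # the middleware logs this at debug level
    return {"title": fallback_title(user_msg)}    # hypothetical, mirrors _fallback_title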
@@ -138,6 +138,6 @@ def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentM
    """Middlewares shared by subagent runtime before subagent-only middlewares."""
    return _build_runtime_middlewares(
        include_uploads=False,
-       include_dangling_tool_call_patch=False,
+       include_dangling_tool_call_patch=True,
        lazy_init=lazy_init,
    )
@@ -10,10 +10,52 @@ from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime

from deerflow.config.paths import Paths, get_paths
from deerflow.utils.file_conversion import extract_outline

logger = logging.getLogger(__name__)


_OUTLINE_PREVIEW_LINES = 5


def _extract_outline_for_file(file_path: Path) -> tuple[list[dict], list[str]]:
    """Return the document outline and fallback preview for *file_path*.

    Looks for a sibling ``<stem>.md`` file produced by the upload conversion
    pipeline.

    Returns:
        (outline, preview) where:
        - outline: list of ``{title, line}`` dicts (plus optional sentinel).
          Empty when no headings are found or no .md exists.
        - preview: first few non-empty lines of the .md, used as a content
          anchor when outline is empty so the agent has some context.
          Empty when outline is non-empty (no fallback needed).
    """
    md_path = file_path.with_suffix(".md")
    if not md_path.is_file():
        return [], []

    outline = extract_outline(md_path)
    if outline:
        logger.debug("Extracted %d outline entries from %s", len(outline), file_path.name)
        return outline, []

    # outline is empty — read the first few non-empty lines as a content preview
    preview: list[str] = []
    try:
        with md_path.open(encoding="utf-8") as f:
            for line in f:
                stripped = line.strip()
                if stripped:
                    preview.append(stripped)
                    if len(preview) >= _OUTLINE_PREVIEW_LINES:
                        break
    except Exception:
        logger.debug("Failed to read preview lines from %s", md_path, exc_info=True)
    return [], preview
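For concreteness, the two return shapes look like this (values illustrative):

# Headings found in the sibling .md: outline populated, preview empty.
#   ([{"title": "Introduction", "line": 1}, {"title": "Results", "line": 42}], [])
# No headings: outline empty, preview carries the first non-empty lines.
#   ([], ["Quarterly revenue summary", "Prepared by the finance team"])
outline, preview = _extract_outline_for_file(Path("/mnt/user-data/uploads/report.pdf"))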
class UploadsMiddlewareState(AgentState):
    """State schema for uploads middleware."""

@@ -39,12 +81,38 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
        super().__init__()
        self._paths = Paths(base_dir) if base_dir else get_paths()

    def _format_file_entry(self, file: dict, lines: list[str]) -> None:
        """Append a single file entry (name, size, path, optional outline) to lines."""
        size_kb = file["size"] / 1024
        size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
        lines.append(f"- {file['filename']} ({size_str})")
        lines.append(f" Path: {file['path']}")
        outline = file.get("outline") or []
        if outline:
            truncated = outline[-1].get("truncated", False)
            visible = [e for e in outline if not e.get("truncated")]
            lines.append(" Document outline (use `read_file` with line ranges to read sections):")
            for entry in visible:
                lines.append(f" L{entry['line']}: {entry['title']}")
            if truncated:
                lines.append(f" ... (showing first {len(visible)} headings; use `read_file` to explore further)")
        else:
            preview = file.get("outline_preview") or []
            if preview:
                lines.append(" No structural headings detected. Document begins with:")
                for text in preview:
                    lines.append(f" > {text}")
                lines.append(" Use `grep` to search for keywords (e.g. `grep(pattern='keyword', path='/mnt/user-data/uploads/')`).")
        lines.append("")

    def _create_files_message(self, new_files: list[dict], historical_files: list[dict]) -> str:
        """Create a formatted message listing uploaded files.

        Args:
            new_files: Files uploaded in the current message.
            historical_files: Files uploaded in previous messages.
                Each file dict may contain an optional ``outline`` key — a list of
                ``{title, line}`` dicts extracted from the converted Markdown file.

        Returns:
            Formatted string inside <uploaded_files> tags.
@@ -55,25 +123,24 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
        lines.append("")
        if new_files:
            for file in new_files:
-               size_kb = file["size"] / 1024
-               size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
-               lines.append(f"- {file['filename']} ({size_str})")
-               lines.append(f" Path: {file['path']}")
-               lines.append("")
+               self._format_file_entry(file, lines)
        else:
            lines.append("(empty)")
            lines.append("")

        if historical_files:
            lines.append("The following files were uploaded in previous messages and are still available:")
            lines.append("")
            for file in historical_files:
-               size_kb = file["size"] / 1024
-               size_str = f"{size_kb:.1f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"
-               lines.append(f"- {file['filename']} ({size_str})")
-               lines.append(f" Path: {file['path']}")
-               lines.append("")
+               self._format_file_entry(file, lines)

-       lines.append("You can read these files using the `read_file` tool with the paths shown above.")
+       lines.append("To work with these files:")
+       lines.append("- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.")
+       lines.append("- Use `grep` to search for keywords when you are not sure which section to look at")
+       lines.append(" (e.g. `grep(pattern='revenue', path='/mnt/user-data/uploads/')`).")
+       lines.append("- Use `glob` to find files by name pattern")
+       lines.append(" (e.g. `glob(pattern='**/*.md', path='/mnt/user-data/uploads/')`).")
+       lines.append("- Only fall back to web search if the file content is clearly insufficient to answer the question.")
        lines.append("</uploaded_files>")

        return "\n".join(lines)
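An illustrative rendering of the message this method now produces (file name and headings hypothetical):

<uploaded_files>

- report.pdf (412.0 KB)
 Path: /mnt/user-data/uploads/report.pdf
 Document outline (use `read_file` with line ranges to read sections):
 L1: Introduction
 L42: Results

To work with these files:
- Read from the file first — use the outline line numbers and `read_file` to locate relevant sections.
...
</uploaded_files>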
@@ -147,6 +214,13 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):

        # Resolve uploads directory for existence checks
        thread_id = (runtime.context or {}).get("thread_id")
        if thread_id is None:
            try:
                from langgraph.config import get_config

                thread_id = get_config().get("configurable", {}).get("thread_id")
            except RuntimeError:
                pass  # get_config() raises outside a runnable context (e.g. unit tests)
        uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None

        # Get newly uploaded files from the current message's additional_kwargs.files
@@ -159,15 +233,26 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
            for file_path in sorted(uploads_dir.iterdir()):
                if file_path.is_file() and file_path.name not in new_filenames:
                    stat = file_path.stat()
                    outline, preview = _extract_outline_for_file(file_path)
                    historical_files.append(
                        {
                            "filename": file_path.name,
                            "size": stat.st_size,
                            "path": f"/mnt/user-data/uploads/{file_path.name}",
                            "extension": file_path.suffix,
                            "outline": outline,
                            "outline_preview": preview,
                        }
                    )

        # Attach outlines to new files as well
        if uploads_dir:
            for file in new_files:
                phys_path = uploads_dir / file["filename"]
                outline, preview = _extract_outline_for_file(phys_path)
                file["outline"] = outline
                file["outline_preview"] = preview

        if not new_files and not historical_files:
            return None
@@ -117,6 +117,7 @@ class DeerFlowClient:
        subagent_enabled: bool = False,
        plan_mode: bool = False,
        agent_name: str | None = None,
        available_skills: set[str] | None = None,
        middlewares: Sequence[AgentMiddleware] | None = None,
    ):
        """Initialize the client.
@@ -133,6 +134,7 @@ class DeerFlowClient:
            subagent_enabled: Enable subagent delegation.
            plan_mode: Enable TodoList middleware for plan mode.
            agent_name: Name of the agent to use.
            available_skills: Optional set of skill names to make available. If None (default), all scanned skills are available.
            middlewares: Optional list of custom middlewares to inject into the agent.
        """
        if config_path is not None:
@@ -148,6 +150,7 @@ class DeerFlowClient:
        self._subagent_enabled = subagent_enabled
        self._plan_mode = plan_mode
        self._agent_name = agent_name
        self._available_skills = set(available_skills) if available_skills is not None else None
        self._middlewares = list(middlewares) if middlewares else []

        # Lazy agent — created on first call, recreated when config changes.
@@ -208,6 +211,8 @@ class DeerFlowClient:
            cfg.get("thinking_enabled"),
            cfg.get("is_plan_mode"),
            cfg.get("subagent_enabled"),
            self._agent_name,
            frozenset(self._available_skills) if self._available_skills is not None else None,
        )

        if self._agent is not None and self._agent_config_key == key:
@@ -226,6 +231,7 @@ class DeerFlowClient:
                subagent_enabled=subagent_enabled,
                max_concurrent_subagents=max_concurrent_subagents,
                agent_name=self._agent_name,
                available_skills=self._available_skills,
            ),
            "state_schema": ThreadState,
        }
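A usage sketch of the new constructor parameters (skill and middleware names hypothetical):

client = DeerFlowClient(
    plan_mode=True,
    subagent_enabled=True,
    available_skills={"pdf", "web-research"},  # hypothetical skill names
    middlewares=[AuditMiddleware()],           # hypothetical custom middleware
)
# available_skills=None (the default) keeps every scanned skill available,
# and the frozenset of skills participates in the agent-rebuild cache key.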
@@ -7,6 +7,7 @@ import uuid
from agent_sandbox import Sandbox as AioSandboxClient

from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.search import GrepMatch, path_matches, should_ignore_path, truncate_line

logger = logging.getLogger(__name__)

@@ -135,6 +136,86 @@ class AioSandbox(Sandbox):
            logger.error(f"Failed to write file in sandbox: {e}")
            raise

    def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
        if not include_dirs:
            result = self._client.file.find_files(path=path, glob=pattern)
            files = result.data.files if result.data and result.data.files else []
            filtered = [file_path for file_path in files if not should_ignore_path(file_path)]
            truncated = len(filtered) > max_results
            return filtered[:max_results], truncated

        result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
        entries = result.data.files if result.data and result.data.files else []
        matches: list[str] = []
        root_path = path.rstrip("/") or "/"
        root_prefix = root_path if root_path == "/" else f"{root_path}/"
        for entry in entries:
            if entry.path != root_path and not entry.path.startswith(root_prefix):
                continue
            if should_ignore_path(entry.path):
                continue
            rel_path = entry.path[len(root_path) :].lstrip("/")
            if path_matches(pattern, rel_path):
                matches.append(entry.path)
                if len(matches) >= max_results:
                    return matches, True
        return matches, False

    def grep(
        self,
        path: str,
        pattern: str,
        *,
        glob: str | None = None,
        literal: bool = False,
        case_sensitive: bool = False,
        max_results: int = 100,
    ) -> tuple[list[GrepMatch], bool]:
        import re as _re

        regex_source = _re.escape(pattern) if literal else pattern
        # Validate the pattern locally so an invalid regex raises re.error
        # (caught by grep_tool's except re.error handler) rather than a
        # generic remote API error.
        _re.compile(regex_source, 0 if case_sensitive else _re.IGNORECASE)
        regex = regex_source if case_sensitive else f"(?i){regex_source}"

        if glob is not None:
            find_result = self._client.file.find_files(path=path, glob=glob)
            candidate_paths = find_result.data.files if find_result.data and find_result.data.files else []
        else:
            list_result = self._client.file.list_path(path=path, recursive=True, show_hidden=False)
            entries = list_result.data.files if list_result.data and list_result.data.files else []
            candidate_paths = [entry.path for entry in entries if not entry.is_directory]

        matches: list[GrepMatch] = []
        truncated = False

        for file_path in candidate_paths:
            if should_ignore_path(file_path):
                continue

            search_result = self._client.file.search_in_file(file=file_path, regex=regex)
            data = search_result.data
            if data is None:
                continue

            line_numbers = data.line_numbers or []
            matched_lines = data.matches or []
            for line_number, line in zip(line_numbers, matched_lines):
                matches.append(
                    GrepMatch(
                        path=file_path,
                        line_number=line_number if isinstance(line_number, int) else 0,
                        line=truncate_line(line),
                    )
                )
                if len(matches) >= max_results:
                    truncated = True
                    return matches, truncated

        return matches, truncated

    def update_file(self, path: str, content: bytes) -> None:
        """Update a file with binary content in the sandbox.
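Example call against the new sandbox search API, given an AioSandbox instance (paths illustrative); with `literal=True` the pattern is escaped before being sent to the remote search:

matches, truncated = sandbox.grep(
    "/mnt/user-data/uploads",
    "config.yaml",   # contains a regex metacharacter (the dot)
    literal=True,    # escaped to r"config\.yaml" before compiling
    max_results=50,
)
for m in matches:
    print(f"{m.path}:{m.line_number}: {m.line}")
if truncated:
    print("(result list truncated at max_results)")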
@@ -1,5 +1,6 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self

@@ -10,15 +11,15 @@ from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.extensions_config import ExtensionsConfig
-from deerflow.config.guardrails_config import load_guardrails_config_from_dict
-from deerflow.config.memory_config import load_memory_config_from_dict
+from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
+from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skills_config import SkillsConfig
from deerflow.config.stream_bridge_config import StreamBridgeConfig, load_stream_bridge_config_from_dict
-from deerflow.config.subagents_config import load_subagents_config_from_dict
-from deerflow.config.summarization_config import load_summarization_config_from_dict
-from deerflow.config.title_config import load_title_config_from_dict
+from deerflow.config.subagents_config import SubagentsAppConfig, load_subagents_config_from_dict
+from deerflow.config.summarization_config import SummarizationConfig, load_summarization_config_from_dict
+from deerflow.config.title_config import TitleConfig, load_title_config_from_dict
from deerflow.config.token_usage_config import TokenUsageConfig
from deerflow.config.tool_config import ToolConfig, ToolGroupConfig
from deerflow.config.tool_search_config import ToolSearchConfig, load_tool_search_config_from_dict
@@ -28,6 +29,13 @@ load_dotenv()
logger = logging.getLogger(__name__)


def _default_config_candidates() -> tuple[Path, ...]:
    """Return deterministic config.yaml locations without relying on cwd."""
    backend_dir = Path(__file__).resolve().parents[4]
    repo_root = backend_dir.parent
    return (backend_dir / "config.yaml", repo_root / "config.yaml")


class AppConfig(BaseModel):
    """Config for the DeerFlow application"""

@@ -40,6 +48,11 @@ class AppConfig(BaseModel):
    skills: SkillsConfig = Field(default_factory=SkillsConfig, description="Skills configuration")
    extensions: ExtensionsConfig = Field(default_factory=ExtensionsConfig, description="Extensions configuration (MCP servers and skills state)")
    tool_search: ToolSearchConfig = Field(default_factory=ToolSearchConfig, description="Tool search / deferred loading configuration")
    title: TitleConfig = Field(default_factory=TitleConfig, description="Automatic title generation configuration")
    summarization: SummarizationConfig = Field(default_factory=SummarizationConfig, description="Conversation summarization configuration")
    memory: MemoryConfig = Field(default_factory=MemoryConfig, description="Memory subsystem configuration")
    subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
    guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
    model_config = ConfigDict(extra="allow", frozen=False)
    checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
    stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@@ -51,7 +64,7 @@ class AppConfig(BaseModel):
        Priority:
        1. If provided `config_path` argument, use it.
        2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
-       3. Otherwise, first check the `config.yaml` in the current directory, then fallback to `config.yaml` in the parent directory.
+       3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
        """
        if config_path:
            path = Path(config_path)
@@ -64,14 +77,10 @@ class AppConfig(BaseModel):
                raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
            return path
        else:
-           # Check if the config.yaml is in the current directory
-           path = Path(os.getcwd()) / "config.yaml"
-           if not path.exists():
-               # Check if the config.yaml is in the parent directory of CWD
-               path = Path(os.getcwd()).parent / "config.yaml"
-               if not path.exists():
-                   raise FileNotFoundError("`config.yaml` file not found at the current directory nor its parent directory")
-           return path
+           for path in _default_config_candidates():
+               if path.exists():
+                   return path
+           raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")

    @classmethod
    def from_file(cls, config_path: str | None = None) -> Self:
@@ -244,6 +253,8 @@ _app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())


def _get_config_mtime(config_path: Path) -> float | None:
@@ -276,6 +287,10 @@ def get_app_config() -> AppConfig:
    """
    global _app_config, _app_config_path, _app_config_mtime

    runtime_override = _current_app_config.get()
    if runtime_override is not None:
        return runtime_override

    if _app_config is not None and _app_config_is_custom:
        return _app_config

@@ -337,3 +352,26 @@ def set_app_config(config: AppConfig) -> None:
    _app_config_path = None
    _app_config_mtime = None
    _app_config_is_custom = True


def peek_current_app_config() -> AppConfig | None:
    """Return the runtime-scoped AppConfig override, if one is active."""
    return _current_app_config.get()


def push_current_app_config(config: AppConfig) -> None:
    """Push a runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    _current_app_config_stack.set(stack + (_current_app_config.get(),))
    _current_app_config.set(config)


def pop_current_app_config() -> None:
    """Pop the latest runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    if not stack:
        _current_app_config.set(None)
        return
    previous = stack[-1]
    _current_app_config_stack.set(stack[:-1])
    _current_app_config.set(previous)
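The push/pop pair is meant to bracket a scoped override; a minimal usage sketch:

override = AppConfig()  # hypothetical per-request configuration
push_current_app_config(override)
try:
    assert get_app_config() is override  # the override wins inside this context
finally:
    pop_current_app_config()  # restores whatever was active before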
@@ -80,6 +80,12 @@ class ExtensionsConfig(BaseModel):
        Args:
            config_path: Optional path to extensions config file.

        Resolution order:
        1. If provided `config_path` argument, use it.
        2. If provided `DEER_FLOW_EXTENSIONS_CONFIG_PATH` environment variable, use it.
        3. Otherwise, search backend/repository-root defaults for
           `extensions_config.json`, then legacy `mcp_config.json`.

        Returns:
            Path to the extensions config file if found, otherwise None.
        """
@@ -94,24 +100,16 @@ class ExtensionsConfig(BaseModel):
                raise FileNotFoundError(f"Extensions config file specified by environment variable `DEER_FLOW_EXTENSIONS_CONFIG_PATH` not found at {path}")
            return path
        else:
-           # Check if the extensions_config.json is in the current directory
-           path = Path(os.getcwd()) / "extensions_config.json"
-           if path.exists():
-               return path
-
-           # Check if the extensions_config.json is in the parent directory of CWD
-           path = Path(os.getcwd()).parent / "extensions_config.json"
-           if path.exists():
-               return path
-
-           # Backward compatibility: check for mcp_config.json
-           path = Path(os.getcwd()) / "mcp_config.json"
-           if path.exists():
-               return path
-
-           path = Path(os.getcwd()).parent / "mcp_config.json"
-           if path.exists():
-               return path
+           backend_dir = Path(__file__).resolve().parents[4]
+           repo_root = backend_dir.parent
+           for path in (
+               backend_dir / "extensions_config.json",
+               repo_root / "extensions_config.json",
+               backend_dir / "mcp_config.json",
+               repo_root / "mcp_config.json",
+           ):
+               if path.exists():
+                   return path

            # Extensions are optional, so return None if not found
            return None
@@ -9,6 +9,12 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")


def _default_local_base_dir() -> Path:
    """Return the repo-local DeerFlow state directory without relying on cwd."""
    backend_dir = Path(__file__).resolve().parents[4]
    return backend_dir / ".deer-flow"


def _validate_thread_id(thread_id: str) -> str:
    """Validate a thread ID before using it in filesystem paths."""
    if not _SAFE_THREAD_ID_RE.match(thread_id):
@@ -67,8 +73,7 @@ class Paths:
    BaseDir resolution (in priority order):
    1. Constructor argument `base_dir`
    2. DEER_FLOW_HOME environment variable
-   3. Local dev fallback: cwd/.deer-flow (when cwd is the backend/ dir)
-   4. Default: $HOME/.deer-flow
+   3. Repo-local fallback derived from this module path: `{backend_dir}/.deer-flow`
    """

    def __init__(self, base_dir: str | Path | None = None) -> None:
@@ -104,11 +109,7 @@ class Paths:
        if env_home := os.getenv("DEER_FLOW_HOME"):
            return Path(env_home).resolve()

-       cwd = Path.cwd()
-       if cwd.name == "backend" or (cwd / "pyproject.toml").exists():
-           return cwd / ".deer-flow"
-
-       return Path.home() / ".deer-flow"
+       return _default_local_base_dir()

    @property
    def memory_file(self) -> Path:
@@ -3,6 +3,11 @@ from pathlib import Path
from pydantic import BaseModel, Field


def _default_repo_root() -> Path:
    """Resolve the repo root without relying on the current working directory."""
    return Path(__file__).resolve().parents[5]


class SkillsConfig(BaseModel):
    """Configuration for skills system"""

@@ -26,8 +31,8 @@ class SkillsConfig(BaseModel):
            # Use configured path (can be absolute or relative)
            path = Path(self.path)
            if not path.is_absolute():
-               # If relative, resolve from current working directory
-               path = Path.cwd() / path
+               # If relative, resolve from the repo root for deterministic behavior.
+               path = _default_repo_root() / path
            return path.resolve()
        else:
            # Default: ../skills relative to backend directory
@@ -15,6 +15,11 @@ class SubagentOverrideConfig(BaseModel):
        ge=1,
        description="Timeout in seconds for this subagent (None = use global default)",
    )
    max_turns: int | None = Field(
        default=None,
        ge=1,
        description="Maximum turns for this subagent (None = use global or builtin default)",
    )


class SubagentsAppConfig(BaseModel):
@@ -25,6 +30,11 @@ class SubagentsAppConfig(BaseModel):
        ge=1,
        description="Default timeout in seconds for all subagents (default: 900 = 15 minutes)",
    )
    max_turns: int | None = Field(
        default=None,
        ge=1,
        description="Optional default max-turn override for all subagents (None = keep builtin defaults)",
    )
    agents: dict[str, SubagentOverrideConfig] = Field(
        default_factory=dict,
        description="Per-agent configuration overrides keyed by agent name",
@@ -44,6 +54,15 @@ class SubagentsAppConfig(BaseModel):
            return override.timeout_seconds
        return self.timeout_seconds

    def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
        """Get the effective max_turns for a specific agent."""
        override = self.agents.get(agent_name)
        if override is not None and override.max_turns is not None:
            return override.max_turns
        if self.max_turns is not None:
            return self.max_turns
        return builtin_default


_subagents_config: SubagentsAppConfig = SubagentsAppConfig()

@@ -58,8 +77,26 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
    global _subagents_config
    _subagents_config = SubagentsAppConfig(**config_dict)

-   overrides_summary = {name: f"{override.timeout_seconds}s" for name, override in _subagents_config.agents.items() if override.timeout_seconds is not None}
+   overrides_summary = {}
+   for name, override in _subagents_config.agents.items():
+       parts = []
+       if override.timeout_seconds is not None:
+           parts.append(f"timeout={override.timeout_seconds}s")
+       if override.max_turns is not None:
+           parts.append(f"max_turns={override.max_turns}")
+       if parts:
+           overrides_summary[name] = ", ".join(parts)

    if overrides_summary:
-       logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, per-agent overrides={overrides_summary}")
+       logger.info(
+           "Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
+           _subagents_config.timeout_seconds,
+           _subagents_config.max_turns,
+           overrides_summary,
+       )
    else:
-       logger.info(f"Subagents config loaded: default timeout={_subagents_config.timeout_seconds}s, no per-agent overrides")
+       logger.info(
+           "Subagents config loaded: default timeout=%ss, default max_turns=%s, no per-agent overrides",
+           _subagents_config.timeout_seconds,
+           _subagents_config.max_turns,
+       )
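Resolution order is per-agent override, then the global default, then the builtin; a sketch with hypothetical values:

cfg = SubagentsAppConfig(
    max_turns=40,
    agents={"researcher": SubagentOverrideConfig(max_turns=80)},
)
assert cfg.get_max_turns_for("researcher", builtin_default=25) == 80  # per-agent override wins
assert cfg.get_max_turns_for("coder", builtin_default=25) == 40       # global default
assert SubagentsAppConfig().get_max_turns_for("coder", builtin_default=25) == 25  # builtin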
@@ -25,6 +25,7 @@ class MemoryStreamBridge(StreamBridge):
        self._maxsize = queue_maxsize
        self._queues: dict[str, asyncio.Queue[StreamEvent]] = {}
        self._counters: dict[str, int] = {}
        self._dropped_counts: dict[str, int] = {}

    # -- helpers ---------------------------------------------------------------

@@ -32,6 +33,7 @@ class MemoryStreamBridge(StreamBridge):
        if run_id not in self._queues:
            self._queues[run_id] = asyncio.Queue(maxsize=self._maxsize)
            self._counters[run_id] = 0
            self._dropped_counts[run_id] = 0
        return self._queues[run_id]

    def _next_id(self, run_id: str) -> str:
@@ -48,14 +50,41 @@ class MemoryStreamBridge(StreamBridge):
        try:
            await asyncio.wait_for(queue.put(entry), timeout=_PUBLISH_TIMEOUT)
        except TimeoutError:
-           logger.warning("Stream bridge queue full for run %s — dropping event %s", run_id, event)
+           self._dropped_counts[run_id] = self._dropped_counts.get(run_id, 0) + 1
+           logger.warning(
+               "Stream bridge queue full for run %s — dropping event %s (total dropped: %d)",
+               run_id,
+               event,
+               self._dropped_counts[run_id],
+           )

    async def publish_end(self, run_id: str) -> None:
        queue = self._get_or_create_queue(run_id)
-       try:
-           await asyncio.wait_for(queue.put(END_SENTINEL), timeout=_PUBLISH_TIMEOUT)
-       except TimeoutError:
-           logger.warning("Stream bridge queue full for run %s — dropping END sentinel", run_id)

+       # END sentinel is critical — it is the only signal that allows
+       # subscribers to terminate. If the queue is full we evict the
+       # oldest *regular* events to make room rather than dropping END,
+       # which would cause the SSE connection to hang forever and leak
+       # the queue/counter resources for this run_id.
+       if queue.full():
+           evicted = 0
+           while queue.full():
+               try:
+                   queue.get_nowait()
+                   evicted += 1
+               except asyncio.QueueEmpty:
+                   break  # pragma: no cover – defensive
+           if evicted:
+               logger.warning(
+                   "Stream bridge queue full for run %s — evicted %d event(s) to guarantee END sentinel delivery",
+                   run_id,
+                   evicted,
+               )

+       # After eviction the queue is guaranteed to have space, so a
+       # simple non-blocking put is safe. We still use put() (which
+       # blocks until space is available) as a defensive measure.
+       await queue.put(END_SENTINEL)

    async def subscribe(
        self,
@@ -84,7 +113,18 @@ class MemoryStreamBridge(StreamBridge):
            await asyncio.sleep(delay)
            self._queues.pop(run_id, None)
            self._counters.pop(run_id, None)
            self._dropped_counts.pop(run_id, None)

    async def close(self) -> None:
        self._queues.clear()
        self._counters.clear()
        self._dropped_counts.clear()

    def dropped_count(self, run_id: str) -> int:
        """Return the number of events dropped for *run_id*."""
        return self._dropped_counts.get(run_id, 0)

    @property
    def dropped_total(self) -> int:
        """Return the total number of events dropped across all runs."""
        return sum(self._dropped_counts.values())
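A test-style sketch of the new guarantee (assuming a `publish(run_id, event)` coroutine and a `queue_maxsize` constructor argument, as the hunks imply):

async def demo() -> None:
    bridge = MemoryStreamBridge(queue_maxsize=2)
    await bridge.publish("run-1", {"type": "chunk", "seq": 1})
    await bridge.publish("run-1", {"type": "chunk", "seq": 2})  # queue now full
    await bridge.publish_end("run-1")  # evicts an old chunk; END is always enqueued
    assert bridge.dropped_count("run-1") == 0  # eviction is not counted as a drop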
@@ -1,72 +1,6 @@
import fnmatch
from pathlib import Path

-IGNORE_PATTERNS = [
-   # Version Control
-   ".git",
-   ".svn",
-   ".hg",
-   ".bzr",
-   # Dependencies
-   "node_modules",
-   "__pycache__",
-   ".venv",
-   "venv",
-   ".env",
-   "env",
-   ".tox",
-   ".nox",
-   ".eggs",
-   "*.egg-info",
-   "site-packages",
-   # Build outputs
-   "dist",
-   "build",
-   ".next",
-   ".nuxt",
-   ".output",
-   ".turbo",
-   "target",
-   "out",
-   # IDE & Editor
-   ".idea",
-   ".vscode",
-   "*.swp",
-   "*.swo",
-   "*~",
-   ".project",
-   ".classpath",
-   ".settings",
-   # OS generated
-   ".DS_Store",
-   "Thumbs.db",
-   "desktop.ini",
-   "*.lnk",
-   # Logs & temp files
-   "*.log",
-   "*.tmp",
-   "*.temp",
-   "*.bak",
-   "*.cache",
-   ".cache",
-   "logs",
-   # Coverage & test artifacts
-   ".coverage",
-   "coverage",
-   ".nyc_output",
-   "htmlcov",
-   ".pytest_cache",
-   ".mypy_cache",
-   ".ruff_cache",
-]
-
-
-def _should_ignore(name: str) -> bool:
-   """Check if a file/directory name matches any ignore pattern."""
-   for pattern in IGNORE_PATTERNS:
-       if fnmatch.fnmatch(name, pattern):
-           return True
-   return False
+from deerflow.sandbox.search import should_ignore_name


def list_dir(path: str, max_depth: int = 2) -> list[str]:
@@ -95,7 +29,7 @@ def list_dir(path: str, max_depth: int = 2) -> list[str]:

    try:
        for item in current_path.iterdir():
-           if _should_ignore(item.name):
+           if should_ignore_name(item.name):
                continue

            post_fix = "/" if item.is_dir() else ""
@@ -1,11 +1,23 @@
import errno
import ntpath
import os
import shutil
import subprocess
from dataclasses import dataclass
from pathlib import Path

from deerflow.sandbox.local.list_dir import list_dir
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches


@dataclass(frozen=True)
class PathMapping:
    """A path mapping from a container path to a local path with optional read-only flag."""

    container_path: str
    local_path: str
    read_only: bool = False


class LocalSandbox(Sandbox):
@@ -39,17 +51,42 @@ class LocalSandbox(Sandbox):

        return None

-   def __init__(self, id: str, path_mappings: dict[str, str] | None = None):
+   def __init__(self, id: str, path_mappings: list[PathMapping] | None = None):
        """
        Initialize local sandbox with optional path mappings.

        Args:
            id: Sandbox identifier
-           path_mappings: Dictionary mapping container paths to local paths
-               Example: {"/mnt/skills": "/absolute/path/to/skills"}
+           path_mappings: List of path mappings with optional read-only flag.
+               Skills directory is read-only by default.
        """
        super().__init__(id)
-       self.path_mappings = path_mappings or {}
+       self.path_mappings = path_mappings or []

    def _is_read_only_path(self, resolved_path: str) -> bool:
        """Check if a resolved path is under a read-only mount.

        When multiple mappings match (nested mounts), prefer the most specific
        mapping (i.e. the one whose local_path is the longest prefix of the
        resolved path), similar to how ``_resolve_path`` handles container paths.
        """
        resolved = str(Path(resolved_path).resolve())

        best_mapping: PathMapping | None = None
        best_prefix_len = -1

        for mapping in self.path_mappings:
            local_resolved = str(Path(mapping.local_path).resolve())
            if resolved == local_resolved or resolved.startswith(local_resolved + os.sep):
                prefix_len = len(local_resolved)
                if prefix_len > best_prefix_len:
                    best_prefix_len = prefix_len
                    best_mapping = mapping

        if best_mapping is None:
            return False

        return best_mapping.read_only

    def _resolve_path(self, path: str) -> str:
        """
@@ -64,7 +101,9 @@ class LocalSandbox(Sandbox):
        path_str = str(path)

        # Try each mapping (longest prefix first for more specific matches)
-       for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True):
+       for mapping in sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True):
+           container_path = mapping.container_path
+           local_path = mapping.local_path
            if path_str == container_path or path_str.startswith(container_path + "/"):
                # Replace the container path prefix with local path
                relative = path_str[len(container_path) :].lstrip("/")
@@ -84,15 +123,16 @@ class LocalSandbox(Sandbox):
        Returns:
            Container path if mapping exists, otherwise original path
        """
-       path_str = str(Path(path).resolve())
+       normalized_path = path.replace("\\", "/")
+       path_str = str(Path(normalized_path).resolve())

        # Try each mapping (longest local path first for more specific matches)
-       for container_path, local_path in sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True):
-           local_path_resolved = str(Path(local_path).resolve())
-           if path_str.startswith(local_path_resolved):
+       for mapping in sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True):
+           local_path_resolved = str(Path(mapping.local_path).resolve())
+           if path_str == local_path_resolved or path_str.startswith(local_path_resolved + "/"):
                # Replace the local path prefix with container path
                relative = path_str[len(local_path_resolved) :].lstrip("/")
-               resolved = f"{container_path}/{relative}" if relative else container_path
+               resolved = f"{mapping.container_path}/{relative}" if relative else mapping.container_path
                return resolved

        # No mapping found, return original path
@@ -111,7 +151,7 @@ class LocalSandbox(Sandbox):
        import re

        # Sort mappings by local path length (longest first) for correct prefix matching
-       sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[1]), reverse=True)
+       sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.local_path), reverse=True)

        if not sorted_mappings:
            return output
@@ -119,12 +159,11 @@ class LocalSandbox(Sandbox):
        # Create pattern that matches absolute paths
        # Match paths like /Users/... or other absolute paths
        result = output
-       for container_path, local_path in sorted_mappings:
-           local_path_resolved = str(Path(local_path).resolve())
+       for mapping in sorted_mappings:
            # Escape the local path for use in regex
-           escaped_local = re.escape(local_path_resolved)
-           # Match the local path followed by optional path components
-           pattern = re.compile(escaped_local + r"(?:/[^\s\"';&|<>()]*)?")
+           escaped_local = re.escape(str(Path(mapping.local_path).resolve()))
+           # Match the local path followed by optional path components with either separator
+           pattern = re.compile(escaped_local + r"(?:[/\\][^\s\"';&|<>()]*)?")

            def replace_match(match: re.Match) -> str:
                matched_path = match.group(0)
@@ -147,7 +186,7 @@ class LocalSandbox(Sandbox):
        import re

        # Sort mappings by length (longest first) for correct prefix matching
-       sorted_mappings = sorted(self.path_mappings.items(), key=lambda x: len(x[0]), reverse=True)
+       sorted_mappings = sorted(self.path_mappings, key=lambda m: len(m.container_path), reverse=True)

        # Build regex pattern to match all container paths
        # Match container path followed by optional path components
@@ -157,7 +196,7 @@ class LocalSandbox(Sandbox):
        # Create pattern that matches any of the container paths.
        # The lookahead (?=/|$|...) ensures we only match at a path-segment boundary,
        # preventing /mnt/skills from matching inside /mnt/skills-extra.
-       patterns = [re.escape(container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for container_path, _ in sorted_mappings]
+       patterns = [re.escape(m.container_path) + r"(?=/|$|[\s\"';&|<>()])(?:/[^\s\"';&|<>()]*)?" for m in sorted_mappings]
        pattern = re.compile("|".join(f"({p})" for p in patterns))

        def replace_match(match: re.Match) -> str:
@@ -248,6 +287,8 @@ class LocalSandbox(Sandbox):

    def write_file(self, path: str, content: str, append: bool = False) -> None:
        resolved_path = self._resolve_path(path)
        if self._is_read_only_path(resolved_path):
            raise OSError(errno.EROFS, "Read-only file system", path)
        try:
            dir_path = os.path.dirname(resolved_path)
            if dir_path:
@@ -259,8 +300,43 @@ class LocalSandbox(Sandbox):
            # Re-raise with the original path for clearer error messages, hiding internal resolved paths
            raise type(e)(e.errno, e.strerror, path) from None

    def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
        resolved_path = Path(self._resolve_path(path))
        matches, truncated = find_glob_matches(resolved_path, pattern, include_dirs=include_dirs, max_results=max_results)
        return [self._reverse_resolve_path(match) for match in matches], truncated

    def grep(
        self,
        path: str,
        pattern: str,
        *,
        glob: str | None = None,
        literal: bool = False,
        case_sensitive: bool = False,
        max_results: int = 100,
    ) -> tuple[list[GrepMatch], bool]:
        resolved_path = Path(self._resolve_path(path))
        matches, truncated = find_grep_matches(
            resolved_path,
            pattern,
            glob_pattern=glob,
            literal=literal,
            case_sensitive=case_sensitive,
            max_results=max_results,
        )
        return [
            GrepMatch(
                path=self._reverse_resolve_path(match.path),
                line_number=match.line_number,
                line=match.line,
            )
            for match in matches
        ], truncated

    def update_file(self, path: str, content: bytes) -> None:
        resolved_path = self._resolve_path(path)
        if self._is_read_only_path(resolved_path):
            raise OSError(errno.EROFS, "Read-only file system", path)
        try:
            dir_path = os.path.dirname(resolved_path)
            if dir_path:
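How the new mapping type composes (paths hypothetical): the most specific matching mount decides writability, and writes under a read-only mount fail with EROFS:

sandbox = LocalSandbox(
    id="dev",
    path_mappings=[
        PathMapping("/mnt/skills", "/home/me/deer-flow/skills", read_only=True),
        PathMapping("/mnt/user-data", "/home/me/.deer-flow/threads/t1"),
    ],
)
sandbox.write_file("/mnt/user-data/notes.txt", "ok")    # allowed
sandbox.write_file("/mnt/skills/pdf/SKILL.md", "nope")  # raises OSError(EROFS)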
@@ -1,6 +1,7 @@
import logging
from pathlib import Path

-from deerflow.sandbox.local.local_sandbox import LocalSandbox
+from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider

@@ -14,16 +15,17 @@ class LocalSandboxProvider(SandboxProvider):
        """Initialize the local sandbox provider with path mappings."""
        self._path_mappings = self._setup_path_mappings()

-   def _setup_path_mappings(self) -> dict[str, str]:
+   def _setup_path_mappings(self) -> list[PathMapping]:
        """
        Setup path mappings for local sandbox.

-       Maps container paths to actual local paths, including skills directory.
+       Maps container paths to actual local paths, including skills directory
+       and any custom mounts configured in config.yaml.

        Returns:
-           Dictionary of path mappings
+           List of path mappings
        """
-       mappings = {}
+       mappings: list[PathMapping] = []

        # Map skills container path to local skills directory
        try:
@@ -35,10 +37,63 @@ class LocalSandboxProvider(SandboxProvider):

            # Only add mapping if skills directory exists
            if skills_path.exists():
-               mappings[container_path] = str(skills_path)
+               mappings.append(
+                   PathMapping(
+                       container_path=container_path,
+                       local_path=str(skills_path),
+                       read_only=True,  # Skills directory is always read-only
+                   )
+               )

            # Map custom mounts from sandbox config
            _RESERVED_CONTAINER_PREFIXES = [container_path, "/mnt/acp-workspace", "/mnt/user-data"]
            sandbox_config = config.sandbox
            if sandbox_config and sandbox_config.mounts:
                for mount in sandbox_config.mounts:
                    host_path = Path(mount.host_path)
                    container_path = mount.container_path.rstrip("/") or "/"

                    if not host_path.is_absolute():
                        logger.warning(
                            "Mount host_path must be absolute, skipping: %s -> %s",
                            mount.host_path,
                            mount.container_path,
                        )
                        continue

                    if not container_path.startswith("/"):
                        logger.warning(
                            "Mount container_path must be absolute, skipping: %s -> %s",
                            mount.host_path,
                            mount.container_path,
                        )
                        continue

                    # Reject mounts that conflict with reserved container paths
                    if any(container_path == p or container_path.startswith(p + "/") for p in _RESERVED_CONTAINER_PREFIXES):
                        logger.warning(
                            "Mount container_path conflicts with reserved prefix, skipping: %s",
                            mount.container_path,
                        )
                        continue
                    # Ensure the host path exists before adding mapping
                    if host_path.exists():
                        mappings.append(
                            PathMapping(
                                container_path=container_path,
                                local_path=str(host_path.resolve()),
                                read_only=mount.read_only,
                            )
                        )
                    else:
                        logger.warning(
                            "Mount host_path does not exist, skipping: %s -> %s",
                            mount.host_path,
                            mount.container_path,
                        )
        except Exception as e:
            # Log but don't fail if config loading fails
-           logger.warning("Could not setup skills path mapping: %s", e, exc_info=True)
+           logger.warning("Could not setup path mappings: %s", e, exc_info=True)

        return mappings
@@ -1,5 +1,7 @@
from abc import ABC, abstractmethod

from deerflow.sandbox.search import GrepMatch


class Sandbox(ABC):
    """Abstract base class for sandbox environments"""
@@ -61,6 +63,25 @@ class Sandbox(ABC):
        """
        pass

    @abstractmethod
    def glob(self, path: str, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
        """Find paths that match a glob pattern under a root directory."""
        pass

    @abstractmethod
    def grep(
        self,
        path: str,
        pattern: str,
        *,
        glob: str | None = None,
        literal: bool = False,
        case_sensitive: bool = False,
        max_results: int = 100,
    ) -> tuple[list[GrepMatch], bool]:
        """Search for matches inside text files under a directory."""
        pass

    @abstractmethod
    def update_file(self, path: str, content: bytes) -> None:
        """Update a file with binary content.
210 backend/packages/harness/deerflow/sandbox/search.py Normal file
@ -0,0 +1,210 @@
|
||||
import fnmatch
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path, PurePosixPath
|
||||
|
||||
IGNORE_PATTERNS = [
|
||||
".git",
|
||||
".svn",
|
||||
".hg",
|
||||
".bzr",
|
||||
"node_modules",
|
||||
"__pycache__",
|
||||
".venv",
|
||||
"venv",
|
||||
".env",
|
||||
"env",
|
||||
".tox",
|
||||
".nox",
|
||||
".eggs",
|
||||
"*.egg-info",
|
||||
"site-packages",
|
||||
"dist",
|
||||
"build",
|
||||
".next",
|
||||
".nuxt",
|
||||
".output",
|
||||
".turbo",
|
||||
"target",
|
||||
"out",
|
||||
".idea",
|
||||
".vscode",
|
||||
"*.swp",
|
||||
"*.swo",
|
||||
"*~",
|
||||
".project",
|
||||
".classpath",
|
||||
".settings",
|
||||
".DS_Store",
|
||||
"Thumbs.db",
|
||||
"desktop.ini",
|
||||
"*.lnk",
|
||||
"*.log",
|
||||
"*.tmp",
|
||||
"*.temp",
|
||||
"*.bak",
|
||||
"*.cache",
|
||||
".cache",
|
||||
"logs",
|
||||
".coverage",
|
||||
"coverage",
|
||||
".nyc_output",
|
||||
"htmlcov",
|
||||
".pytest_cache",
|
||||
".mypy_cache",
|
||||
".ruff_cache",
|
||||
]
|
||||
|
||||
DEFAULT_MAX_FILE_SIZE_BYTES = 1_000_000
|
||||
DEFAULT_LINE_SUMMARY_LENGTH = 200
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GrepMatch:
|
||||
path: str
|
||||
line_number: int
|
||||
line: str
|
||||
|
||||
|
||||
def should_ignore_name(name: str) -> bool:
|
||||
for pattern in IGNORE_PATTERNS:
|
||||
if fnmatch.fnmatch(name, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def should_ignore_path(path: str) -> bool:
|
||||
return any(should_ignore_name(segment) for segment in path.replace("\\", "/").split("/") if segment)
|
||||
|
||||
|
||||
def path_matches(pattern: str, rel_path: str) -> bool:
|
||||
path = PurePosixPath(rel_path)
|
||||
if path.match(pattern):
|
||||
return True
|
||||
if pattern.startswith("**/"):
|
||||
return path.match(pattern[3:])
|
||||
return False
|
||||
|
||||
|
||||
def truncate_line(line: str, max_chars: int = DEFAULT_LINE_SUMMARY_LENGTH) -> str:
|
||||
line = line.rstrip("\n\r")
|
||||
if len(line) <= max_chars:
|
||||
return line
|
||||
return line[: max_chars - 3] + "..."
|
||||
|
||||
|
||||
def is_binary_file(path: Path, sample_size: int = 8192) -> bool:
|
||||
try:
|
||||
with path.open("rb") as handle:
|
||||
return b"\0" in handle.read(sample_size)
|
||||
except OSError:
|
||||
return True
|
||||
|
||||
|
||||
def find_glob_matches(root: Path, pattern: str, *, include_dirs: bool = False, max_results: int = 200) -> tuple[list[str], bool]:
|
||||
matches: list[str] = []
|
||||
truncated = False
|
||||
root = root.resolve()
|
||||
|
||||
if not root.exists():
|
||||
raise FileNotFoundError(root)
|
||||
if not root.is_dir():
|
||||
raise NotADirectoryError(root)
|
||||
|
||||
for current_root, dirs, files in os.walk(root):
|
||||
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
|
||||
# root is already resolved; os.walk builds current_root by joining under root,
|
||||
# so relative_to() works without an extra stat()/resolve() per directory.
|
||||
rel_dir = Path(current_root).relative_to(root)
|
||||
|
||||
if include_dirs:
|
||||
for name in dirs:
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(str(Path(current_root) / name))
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
for name in files:
|
||||
if should_ignore_name(name):
|
||||
continue
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
if path_matches(pattern, rel_path):
|
||||
matches.append(str(Path(current_root) / name))
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
|
||||
return matches, truncated
|
||||
|
||||
|
||||
def find_grep_matches(
|
||||
root: Path,
|
||||
pattern: str,
|
||||
*,
|
||||
glob_pattern: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = 100,
|
||||
max_file_size: int = DEFAULT_MAX_FILE_SIZE_BYTES,
|
||||
line_summary_length: int = DEFAULT_LINE_SUMMARY_LENGTH,
|
||||
) -> tuple[list[GrepMatch], bool]:
|
||||
matches: list[GrepMatch] = []
|
||||
truncated = False
|
||||
root = root.resolve()
|
||||
|
||||
if not root.exists():
|
||||
raise FileNotFoundError(root)
|
||||
if not root.is_dir():
|
||||
raise NotADirectoryError(root)
|
||||
|
||||
regex_source = re.escape(pattern) if literal else pattern
|
||||
flags = 0 if case_sensitive else re.IGNORECASE
|
||||
regex = re.compile(regex_source, flags)
|
||||
|
||||
# Skip lines longer than this to prevent ReDoS on minified / no-newline files.
|
||||
_max_line_chars = line_summary_length * 10
|
||||
|
||||
for current_root, dirs, files in os.walk(root):
|
||||
dirs[:] = [name for name in dirs if not should_ignore_name(name)]
|
||||
rel_dir = Path(current_root).relative_to(root)
|
||||
|
||||
for name in files:
|
||||
if should_ignore_name(name):
|
||||
continue
|
||||
|
||||
candidate_path = Path(current_root) / name
|
||||
rel_path = (rel_dir / name).as_posix()
|
||||
|
||||
if glob_pattern is not None and not path_matches(glob_pattern, rel_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
if candidate_path.is_symlink():
|
||||
continue
|
||||
file_path = candidate_path.resolve()
|
||||
if not file_path.is_relative_to(root):
|
||||
continue
|
||||
if file_path.stat().st_size > max_file_size or is_binary_file(file_path):
|
||||
continue
|
||||
with file_path.open(encoding="utf-8", errors="replace") as handle:
|
||||
for line_number, line in enumerate(handle, start=1):
|
||||
if len(line) > _max_line_chars:
|
||||
continue
|
||||
if regex.search(line):
|
||||
matches.append(
|
||||
GrepMatch(
|
||||
path=str(file_path),
|
||||
line_number=line_number,
|
||||
line=truncate_line(line, line_summary_length),
|
||||
)
|
||||
)
|
||||
if len(matches) >= max_results:
|
||||
truncated = True
|
||||
return matches, truncated
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
return matches, truncated
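And the grep counterpart; again the arguments are illustrative, not taken from the commit:

from pathlib import Path

matches, truncated = find_grep_matches(
    Path("backend/src"),
    "TODO",
    glob_pattern="**/*.py",
    literal=True,   # plain-string search, no regex metacharacters
    max_results=20,
)
for m in matches:
    print(f"{m.path}:{m.line_number}: {m.line}")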
|
||||
@ -7,6 +7,7 @@ from langchain.tools import ToolRuntime, tool
|
||||
from langgraph.typing import ContextT
|
||||
|
||||
from deerflow.agents.thread_state import ThreadDataState, ThreadState
|
||||
from deerflow.config import get_app_config
|
||||
from deerflow.config.paths import VIRTUAL_PATH_PREFIX
|
||||
from deerflow.sandbox.exceptions import (
|
||||
SandboxError,
|
||||
@ -16,6 +17,7 @@ from deerflow.sandbox.exceptions import (
|
||||
from deerflow.sandbox.file_operation_lock import get_file_operation_lock
|
||||
from deerflow.sandbox.sandbox import Sandbox
|
||||
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
|
||||
from deerflow.sandbox.search import GrepMatch
|
||||
from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed
|
||||
|
||||
_ABSOLUTE_PATH_PATTERN = re.compile(r"(?<![:\w])(?<!:/)/(?:[^\s\"'`;&|<>()]+)")
|
||||
@ -31,6 +33,10 @@ _LOCAL_BASH_SYSTEM_PATH_PREFIXES = (
|
||||
|
||||
_DEFAULT_SKILLS_CONTAINER_PATH = "/mnt/skills"
|
||||
_ACP_WORKSPACE_VIRTUAL_PATH = "/mnt/acp-workspace"
|
||||
_DEFAULT_GLOB_MAX_RESULTS = 200
|
||||
_MAX_GLOB_MAX_RESULTS = 1000
|
||||
_DEFAULT_GREP_MAX_RESULTS = 100
|
||||
_MAX_GREP_MAX_RESULTS = 500
|
||||
|
||||
|
||||
def _get_skills_container_path() -> str:
|
||||
@ -113,6 +119,54 @@ def _is_acp_workspace_path(path: str) -> bool:
|
||||
return path == _ACP_WORKSPACE_VIRTUAL_PATH or path.startswith(f"{_ACP_WORKSPACE_VIRTUAL_PATH}/")
|
||||
|
||||
|
||||
def _get_custom_mounts():
|
||||
"""Get custom volume mounts from sandbox config.
|
||||
|
||||
Result is cached after the first successful config load. If config loading
|
||||
fails, an empty list is returned *without* caching so that a later call can
|
||||
pick up the real value once the config is available.
|
||||
"""
|
||||
cached = getattr(_get_custom_mounts, "_cached", None)
|
||||
if cached is not None:
|
||||
return cached
|
||||
try:
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
config = get_app_config()
|
||||
mounts = []
|
||||
if config.sandbox and config.sandbox.mounts:
|
||||
# Only include mounts whose host_path exists, consistent with
|
||||
# LocalSandboxProvider._setup_path_mappings(), which also filters
|
||||
# by host_path.exists().
|
||||
mounts = [m for m in config.sandbox.mounts if Path(m.host_path).exists()]
|
||||
_get_custom_mounts._cached = mounts # type: ignore[attr-defined]
|
||||
return mounts
|
||||
except Exception:
|
||||
# If config loading fails, return an empty list without caching so that
|
||||
# a later call can retry once the config is available.
|
||||
return []
|
||||
|
||||
|
||||
def _is_custom_mount_path(path: str) -> bool:
|
||||
"""Check if path is under a custom mount container_path."""
|
||||
for mount in _get_custom_mounts():
|
||||
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _get_custom_mount_for_path(path: str):
|
||||
"""Get the mount config matching this path (longest prefix first)."""
|
||||
best = None
|
||||
for mount in _get_custom_mounts():
|
||||
if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
|
||||
if best is None or len(mount.container_path) > len(best.container_path):
|
||||
best = mount
|
||||
return best
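To illustrate the longest-prefix rule with two nested mounts (SimpleNamespace stands in for the real mount config objects; logic mirrors the helper above):

from types import SimpleNamespace

mounts = [
    SimpleNamespace(container_path="/mnt/data", read_only=False),
    SimpleNamespace(container_path="/mnt/data/models", read_only=True),
]

def pick(path: str):
    best = None
    for mount in mounts:
        if path == mount.container_path or path.startswith(f"{mount.container_path}/"):
            if best is None or len(mount.container_path) > len(best.container_path):
                best = mount
    return best

assert pick("/mnt/data/models/weights.bin").container_path == "/mnt/data/models"
assert pick("/mnt/data/readme.txt").container_path == "/mnt/data"
assert pick("/mnt/other") is None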
|
||||
|
||||
|
||||
def _extract_thread_id_from_thread_data(thread_data: "ThreadDataState | None") -> str | None:
|
||||
"""Extract thread_id from thread_data by inspecting workspace_path.
|
||||
|
||||
@ -245,16 +299,84 @@ def _get_mcp_allowed_paths() -> list[str]:
|
||||
return allowed_paths
|
||||
|
||||
|
||||
def _get_tool_config_int(name: str, key: str, default: int) -> int:
|
||||
try:
|
||||
tool_config = get_app_config().get_tool_config(name)
|
||||
if tool_config is not None and key in tool_config.model_extra:
|
||||
value = tool_config.model_extra.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
except Exception:
|
||||
pass
|
||||
return default
|
||||
|
||||
|
||||
def _clamp_max_results(value: int, *, default: int, upper_bound: int) -> int:
|
||||
if value <= 0:
|
||||
return default
|
||||
return min(value, upper_bound)
|
||||
|
||||
|
||||
def _resolve_max_results(name: str, requested: int, *, default: int, upper_bound: int) -> int:
|
||||
requested_max_results = _clamp_max_results(requested, default=default, upper_bound=upper_bound)
|
||||
configured_max_results = _clamp_max_results(
|
||||
_get_tool_config_int(name, "max_results", default),
|
||||
default=default,
|
||||
upper_bound=upper_bound,
|
||||
)
|
||||
return min(requested_max_results, configured_max_results)
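Worked examples of the clamping above, using the glob defaults (default 200, upper bound 1000):

# Non-positive requests fall back to the default; oversized ones are capped.
assert _clamp_max_results(-5, default=200, upper_bound=1000) == 200
assert _clamp_max_results(5000, default=200, upper_bound=1000) == 1000
# If config.yaml sets max_results=300 for "glob" and a tool call requests 500,
# _resolve_max_results returns min(500, 300) == 300.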
|
||||
|
||||
|
||||
def _resolve_local_read_path(path: str, thread_data: ThreadDataState) -> str:
|
||||
validate_local_tool_path(path, thread_data, read_only=True)
|
||||
if _is_skills_path(path):
|
||||
return _resolve_skills_path(path)
|
||||
if _is_acp_workspace_path(path):
|
||||
return _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
return _resolve_and_validate_user_data_path(path, thread_data)
|
||||
|
||||
|
||||
def _format_glob_results(root_path: str, matches: list[str], truncated: bool) -> str:
|
||||
if not matches:
|
||||
return f"No files matched under {root_path}"
|
||||
|
||||
lines = [f"Found {len(matches)} paths under {root_path}"]
|
||||
if truncated:
|
||||
lines[0] += f" (showing first {len(matches)})"
|
||||
lines.extend(f"{index}. {path}" for index, path in enumerate(matches, start=1))
|
||||
if truncated:
|
||||
lines.append("Results truncated. Narrow the path or pattern to see fewer matches.")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_grep_results(root_path: str, matches: list[GrepMatch], truncated: bool) -> str:
|
||||
if not matches:
|
||||
return f"No matches found under {root_path}"
|
||||
|
||||
lines = [f"Found {len(matches)} matches under {root_path}"]
|
||||
if truncated:
|
||||
lines[0] += f" (showing first {len(matches)})"
|
||||
lines.extend(f"{match.path}:{match.line_number}: {match.line}" for match in matches)
|
||||
if truncated:
|
||||
lines.append("Results truncated. Narrow the path or add a glob filter.")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _path_variants(path: str) -> set[str]:
|
||||
return {path, path.replace("\\", "/"), path.replace("/", "\\")}
|
||||
|
||||
|
||||
def _path_separator_for_style(path: str) -> str:
|
||||
return "\\" if "\\" in path and "/" not in path else "/"
|
||||
|
||||
|
||||
def _join_path_preserving_style(base: str, relative: str) -> str:
|
||||
if not relative:
|
||||
return base
|
||||
if "/" in base and "\\" not in base:
|
||||
return f"{base.rstrip('/')}/{relative}"
|
||||
return str(Path(base) / relative)
|
||||
separator = _path_separator_for_style(base)
|
||||
normalized_relative = relative.replace("\\" if separator == "/" else "/", separator).lstrip("/\\")
|
||||
stripped_base = base.rstrip("/\\")
|
||||
return f"{stripped_base}{separator}{normalized_relative}"
|
||||
|
||||
|
||||
def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str:
|
||||
@ -299,7 +421,10 @@ def replace_virtual_path(path: str, thread_data: ThreadDataState | None) -> str:
|
||||
return actual_base
|
||||
if path.startswith(f"{virtual_base}/"):
|
||||
rest = path[len(virtual_base) :].lstrip("/")
|
||||
return _join_path_preserving_style(actual_base, rest)
|
||||
result = _join_path_preserving_style(actual_base, rest)
|
||||
if path.endswith("/") and not result.endswith(("/", "\\")):
|
||||
result += _path_separator_for_style(actual_base)
|
||||
return result
|
||||
|
||||
return path
|
||||
|
||||
@ -379,6 +504,8 @@ def mask_local_paths_in_output(output: str, thread_data: ThreadDataState | None)
|
||||
|
||||
result = pattern.sub(replace_acp, result)
|
||||
|
||||
# Custom mount host paths are masked by LocalSandbox._reverse_resolve_paths_in_output()
|
||||
|
||||
# Mask user-data host paths
|
||||
if thread_data is None:
|
||||
return result
|
||||
@ -427,6 +554,7 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
|
||||
- ``/mnt/user-data/*`` — always allowed (read + write)
|
||||
- ``/mnt/skills/*`` — allowed only when *read_only* is True
|
||||
- ``/mnt/acp-workspace/*`` — allowed only when *read_only* is True
|
||||
- Custom mount paths (from config.yaml) — respects per-mount ``read_only`` flag
|
||||
|
||||
Args:
|
||||
path: The virtual path to validate.
|
||||
@ -458,7 +586,14 @@ def validate_local_tool_path(path: str, thread_data: ThreadDataState | None, *,
|
||||
if path.startswith(f"{VIRTUAL_PATH_PREFIX}/"):
|
||||
return
|
||||
|
||||
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, or {_ACP_WORKSPACE_VIRTUAL_PATH}/ are allowed")
|
||||
# Custom mount paths — respect read_only config
|
||||
if _is_custom_mount_path(path):
|
||||
mount = _get_custom_mount_for_path(path)
|
||||
if mount and mount.read_only and not read_only:
|
||||
raise PermissionError(f"Write access to read-only mount is not allowed: {path}")
|
||||
return
|
||||
|
||||
raise PermissionError(f"Only paths under {VIRTUAL_PATH_PREFIX}/, {_get_skills_container_path()}/, {_ACP_WORKSPACE_VIRTUAL_PATH}/, or configured mount paths are allowed")
|
||||
|
||||
|
||||
def _validate_resolved_user_data_path(resolved: Path, thread_data: ThreadDataState) -> None:
|
||||
@ -508,9 +643,10 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
|
||||
boundary and must not be treated as isolation from the host filesystem.
|
||||
|
||||
In local mode, commands must use virtual paths under /mnt/user-data for
|
||||
user data access. Skills paths under /mnt/skills and ACP workspace paths
|
||||
under /mnt/acp-workspace are allowed (path-traversal checks only; write
|
||||
prevention for bash commands is not enforced here).
|
||||
user data access. Skills paths under /mnt/skills, ACP workspace paths
|
||||
under /mnt/acp-workspace, and custom mount container paths (configured in
|
||||
config.yaml) are allowed (path-traversal checks only; write prevention
|
||||
for bash commands is not enforced here).
|
||||
A small allowlist of common system path prefixes is kept for executable
|
||||
and device references (e.g. /bin/sh, /dev/null).
|
||||
"""
|
||||
@ -545,6 +681,11 @@ def validate_local_bash_command_paths(command: str, thread_data: ThreadDataState
|
||||
_reject_path_traversal(absolute_path)
|
||||
continue
|
||||
|
||||
# Allow custom mount container paths
|
||||
if _is_custom_mount_path(absolute_path):
|
||||
_reject_path_traversal(absolute_path)
|
||||
continue
|
||||
|
||||
if any(absolute_path == prefix.rstrip("/") or absolute_path.startswith(prefix) for prefix in _LOCAL_BASH_SYSTEM_PATH_PREFIXES):
|
||||
continue
|
||||
|
||||
@ -589,6 +730,8 @@ def replace_virtual_paths_in_command(command: str, thread_data: ThreadDataState
|
||||
|
||||
result = acp_pattern.sub(replace_acp_match, result)
|
||||
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_paths_in_command()
|
||||
|
||||
# Replace user-data paths
|
||||
if VIRTUAL_PATH_PREFIX in result and thread_data is not None:
|
||||
pattern = re.compile(rf"{re.escape(VIRTUAL_PATH_PREFIX)}(/[^\s\"';&|<>()]*)?")
|
||||
@ -666,7 +809,8 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError(f"Sandbox with ID '{sandbox_id}' not found", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for downstream use
|
||||
return sandbox
|
||||
|
||||
|
||||
@ -701,7 +845,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
if sandbox_id is not None:
|
||||
sandbox = get_sandbox_provider().get(sandbox_id)
|
||||
if sandbox is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
# Sandbox was released, fall through to acquire new one
|
||||
|
||||
@ -723,7 +868,8 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non
|
||||
if sandbox is None:
|
||||
raise SandboxNotFoundError("Sandbox not found after acquisition", sandbox_id=sandbox_id)
|
||||
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
if runtime.context is not None:
|
||||
runtime.context["sandbox_id"] = sandbox_id # Ensure sandbox_id is in context for releasing in after_agent
|
||||
return sandbox
|
||||
|
||||
|
||||
@ -885,8 +1031,9 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
path = _resolve_skills_path(path)
|
||||
elif _is_acp_workspace_path(path):
|
||||
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
else:
|
||||
elif not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
children = sandbox.list_dir(path)
|
||||
if not children:
|
||||
return "(empty)"
|
||||
@ -901,6 +1048,126 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path:
|
||||
return f"Error: Unexpected error listing directory: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("glob", parse_docstring=True)
|
||||
def glob_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
include_dirs: bool = False,
|
||||
max_results: int = _DEFAULT_GLOB_MAX_RESULTS,
|
||||
) -> str:
|
||||
"""Find files or directories that match a glob pattern under a root directory.
|
||||
|
||||
Args:
|
||||
description: Explain why you are searching for these paths in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
pattern: The glob pattern to match relative to the root path, for example `**/*.py`.
|
||||
path: The **absolute** root directory to search under.
|
||||
include_dirs: Whether matching directories should also be returned. Default is False.
|
||||
max_results: Maximum number of paths to return. Default is 200.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
effective_max_results = _resolve_max_results(
|
||||
"glob",
|
||||
max_results,
|
||||
default=_DEFAULT_GLOB_MAX_RESULTS,
|
||||
upper_bound=_MAX_GLOB_MAX_RESULTS,
|
||||
)
|
||||
thread_data = None
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
path = _resolve_local_read_path(path, thread_data)
|
||||
matches, truncated = sandbox.glob(path, pattern, include_dirs=include_dirs, max_results=effective_max_results)
|
||||
if thread_data is not None:
|
||||
matches = [mask_local_paths_in_output(match, thread_data) for match in matches]
|
||||
return _format_glob_results(requested_path, matches, truncated)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: Directory not found: {requested_path}"
|
||||
except NotADirectoryError:
|
||||
return f"Error: Path is not a directory: {requested_path}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error searching paths: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("grep", parse_docstring=True)
|
||||
def grep_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
description: str,
|
||||
pattern: str,
|
||||
path: str,
|
||||
glob: str | None = None,
|
||||
literal: bool = False,
|
||||
case_sensitive: bool = False,
|
||||
max_results: int = _DEFAULT_GREP_MAX_RESULTS,
|
||||
) -> str:
|
||||
"""Search for matching lines inside text files under a root directory.
|
||||
|
||||
Args:
|
||||
description: Explain why you are searching file contents in short words. ALWAYS PROVIDE THIS PARAMETER FIRST.
|
||||
pattern: The string or regex pattern to search for.
|
||||
path: The **absolute** root directory to search under.
|
||||
glob: Optional glob filter for candidate files, for example `**/*.py`.
|
||||
literal: Whether to treat `pattern` as a plain string. Default is False.
|
||||
case_sensitive: Whether matching is case-sensitive. Default is False.
|
||||
max_results: Maximum number of matching lines to return. Default is 100.
|
||||
"""
|
||||
try:
|
||||
sandbox = ensure_sandbox_initialized(runtime)
|
||||
ensure_thread_directories_exist(runtime)
|
||||
requested_path = path
|
||||
effective_max_results = _resolve_max_results(
|
||||
"grep",
|
||||
max_results,
|
||||
default=_DEFAULT_GREP_MAX_RESULTS,
|
||||
upper_bound=_MAX_GREP_MAX_RESULTS,
|
||||
)
|
||||
thread_data = None
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
if thread_data is None:
|
||||
raise SandboxRuntimeError("Thread data not available for local sandbox")
|
||||
path = _resolve_local_read_path(path, thread_data)
|
||||
matches, truncated = sandbox.grep(
|
||||
path,
|
||||
pattern,
|
||||
glob=glob,
|
||||
literal=literal,
|
||||
case_sensitive=case_sensitive,
|
||||
max_results=effective_max_results,
|
||||
)
|
||||
if thread_data is not None:
|
||||
matches = [
|
||||
GrepMatch(
|
||||
path=mask_local_paths_in_output(match.path, thread_data),
|
||||
line_number=match.line_number,
|
||||
line=match.line,
|
||||
)
|
||||
for match in matches
|
||||
]
|
||||
return _format_grep_results(requested_path, matches, truncated)
|
||||
except SandboxError as e:
|
||||
return f"Error: {e}"
|
||||
except FileNotFoundError:
|
||||
return f"Error: Directory not found: {requested_path}"
|
||||
except NotADirectoryError:
|
||||
return f"Error: Path is not a directory: {requested_path}"
|
||||
except re.error as e:
|
||||
return f"Error: Invalid regex pattern: {e}"
|
||||
except PermissionError:
|
||||
return f"Error: Permission denied: {requested_path}"
|
||||
except Exception as e:
|
||||
return f"Error: Unexpected error searching file contents: {_sanitize_error(e, runtime)}"
|
||||
|
||||
|
||||
@tool("read_file", parse_docstring=True)
|
||||
def read_file_tool(
|
||||
runtime: ToolRuntime[ContextT, ThreadState],
|
||||
@ -928,8 +1195,9 @@ def read_file_tool(
|
||||
path = _resolve_skills_path(path)
|
||||
elif _is_acp_workspace_path(path):
|
||||
path = _resolve_acp_workspace_path(path, _extract_thread_id_from_thread_data(thread_data))
|
||||
else:
|
||||
elif not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
return "(empty)"
|
||||
@ -977,7 +1245,9 @@ def write_file_tool(
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
if not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
with get_file_operation_lock(sandbox, path):
|
||||
sandbox.write_file(path, content, append)
|
||||
return "OK"
|
||||
@ -1019,7 +1289,9 @@ def str_replace_tool(
|
||||
if is_local_sandbox(runtime):
|
||||
thread_data = get_thread_data(runtime)
|
||||
validate_local_tool_path(path, thread_data)
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
if not _is_custom_mount_path(path):
|
||||
path = _resolve_and_validate_user_data_path(path, thread_data)
|
||||
# Custom mount paths are resolved by LocalSandbox._resolve_path()
|
||||
with get_file_operation_lock(sandbox, path):
|
||||
content = sandbox.read_file(path)
|
||||
if not content:
|
||||
|
||||
@ -43,5 +43,5 @@ You have access to the sandbox environment:
|
||||
tools=["bash", "ls", "read_file", "write_file", "str_replace"], # Sandbox tools only
|
||||
disallowed_tools=["task", "ask_clarification", "present_files"],
|
||||
model="inherit",
|
||||
max_turns=30,
|
||||
max_turns=60,
|
||||
)
|
||||
|
||||
@ -44,5 +44,5 @@ You have access to the same sandbox environment as the parent agent:
|
||||
tools=None, # Inherit all tools from parent
|
||||
disallowed_tools=["task", "ask_clarification", "present_files"], # Prevent nesting and clarification
|
||||
model="inherit",
|
||||
max_turns=50,
|
||||
max_turns=100,
|
||||
)
|
||||
|
||||
@ -28,9 +28,27 @@ def get_subagent_config(name: str) -> SubagentConfig | None:
|
||||
|
||||
app_config = get_subagents_app_config()
|
||||
effective_timeout = app_config.get_timeout_for(name)
|
||||
effective_max_turns = app_config.get_max_turns_for(name, config.max_turns)
|
||||
|
||||
overrides = {}
|
||||
if effective_timeout != config.timeout_seconds:
|
||||
logger.debug(f"Subagent '{name}': timeout overridden by config.yaml ({config.timeout_seconds}s -> {effective_timeout}s)")
|
||||
config = replace(config, timeout_seconds=effective_timeout)
|
||||
logger.debug(
|
||||
"Subagent '%s': timeout overridden by config.yaml (%ss -> %ss)",
|
||||
name,
|
||||
config.timeout_seconds,
|
||||
effective_timeout,
|
||||
)
|
||||
overrides["timeout_seconds"] = effective_timeout
|
||||
if effective_max_turns != config.max_turns:
|
||||
logger.debug(
|
||||
"Subagent '%s': max_turns overridden by config.yaml (%s -> %s)",
|
||||
name,
|
||||
config.max_turns,
|
||||
effective_max_turns,
|
||||
)
|
||||
overrides["max_turns"] = effective_max_turns
|
||||
if overrides:
|
||||
config = replace(config, **overrides)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@ -57,6 +57,42 @@ def _build_mcp_servers() -> dict[str, dict[str, Any]]:
|
||||
return build_servers_config(ExtensionsConfig.from_file())
|
||||
|
||||
|
||||
def _build_acp_mcp_servers() -> list[dict[str, Any]]:
|
||||
"""Build ACP ``mcpServers`` payload for ``new_session``.
|
||||
|
||||
The ACP client expects a list of server objects, while DeerFlow's MCP helper
|
||||
returns a name -> config mapping for the LangChain MCP adapter. This helper
|
||||
converts the enabled servers into the ACP wire format.
|
||||
"""
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
|
||||
extensions_config = ExtensionsConfig.from_file()
|
||||
enabled_servers = extensions_config.get_enabled_mcp_servers()
|
||||
|
||||
mcp_servers: list[dict[str, Any]] = []
|
||||
for name, server_config in enabled_servers.items():
|
||||
transport_type = server_config.type or "stdio"
|
||||
payload: dict[str, Any] = {"name": name, "type": transport_type}
|
||||
|
||||
if transport_type == "stdio":
|
||||
if not server_config.command:
|
||||
raise ValueError(f"MCP server '{name}' with stdio transport requires 'command' field")
|
||||
payload["command"] = server_config.command
|
||||
payload["args"] = server_config.args
|
||||
payload["env"] = [{"name": key, "value": value} for key, value in server_config.env.items()]
|
||||
elif transport_type in ("http", "sse"):
|
||||
if not server_config.url:
|
||||
raise ValueError(f"MCP server '{name}' with {transport_type} transport requires 'url' field")
|
||||
payload["url"] = server_config.url
|
||||
payload["headers"] = [{"name": key, "value": value} for key, value in server_config.headers.items()]
|
||||
else:
|
||||
raise ValueError(f"MCP server '{name}' has unsupported transport type: {transport_type}")
|
||||
|
||||
mcp_servers.append(payload)
|
||||
|
||||
return mcp_servers
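For reference, a stdio entry produced by this helper would look like the sketch below (server name and values are made up; note that env and headers are lists of name/value objects rather than plain dicts):

payload = {
    "name": "filesystem",
    "type": "stdio",
    "command": "npx",
    "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
    "env": [{"name": "LOG_LEVEL", "value": "info"}],
}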
|
||||
|
||||
|
||||
def _build_permission_response(options: list[Any], *, auto_approve: bool) -> Any:
|
||||
"""Build an ACP permission response.
|
||||
|
||||
@ -173,7 +209,15 @@ def build_invoke_acp_agent_tool(agents: dict) -> BaseTool:
|
||||
cmd = agent_config.command
|
||||
args = agent_config.args or []
|
||||
physical_cwd = _get_work_dir(thread_id)
|
||||
mcp_servers = _build_mcp_servers()
|
||||
try:
|
||||
mcp_servers = _build_acp_mcp_servers()
|
||||
except ValueError as exc:
|
||||
logger.warning(
|
||||
"Invalid MCP server configuration for ACP agent '%s'; continuing without MCP servers: %s",
|
||||
agent,
|
||||
exc,
|
||||
)
|
||||
mcp_servers = []
|
||||
agent_env: dict[str, str] | None = None
|
||||
if agent_config.env:
|
||||
agent_env = {k: (os.environ.get(v[1:], "") if v.startswith("$") else v) for k, v in agent_config.env.items()}
|
||||
|
||||
@ -1,10 +1,22 @@
|
||||
"""File conversion utilities.
|
||||
|
||||
Converts document files (PDF, PPT, Excel, Word) to Markdown using markitdown.
|
||||
Converts document files (PDF, PPT, Excel, Word) to Markdown.
|
||||
|
||||
PDF conversion strategy (auto mode):
|
||||
1. Try pymupdf4llm if installed — better heading detection, faster on most files.
|
||||
2. If output is suspiciously short (< _MIN_CHARS_PER_PAGE chars/page, or < 200 chars
|
||||
total when page count is unavailable), treat as image-based and fall back to MarkItDown.
|
||||
3. If pymupdf4llm is not installed, use MarkItDown directly (existing behaviour).
|
||||
|
||||
Large files (> ASYNC_THRESHOLD_BYTES) are converted in a thread pool via
|
||||
asyncio.to_thread() to avoid blocking the event loop (fixes #1569).
|
||||
|
||||
No FastAPI or HTTP dependencies — pure utility functions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -20,28 +32,278 @@ CONVERTIBLE_EXTENSIONS = {
|
||||
".docx",
|
||||
}
|
||||
|
||||
# Files larger than this threshold are converted in a background thread.
|
||||
# Small files complete in < 1s synchronously; spawning a thread adds unnecessary
|
||||
# scheduling overhead for them.
|
||||
_ASYNC_THRESHOLD_BYTES = 1 * 1024 * 1024 # 1 MB
|
||||
|
||||
# If pymupdf4llm produces fewer characters *per page* than this threshold,
|
||||
# the PDF is likely image-based or encrypted — fall back to MarkItDown.
|
||||
# Rationale: normal text PDFs yield 200-2000 chars/page; image-based PDFs
|
||||
# yield close to 0. 50 chars/page gives a wide safety margin.
|
||||
# Falls back to absolute 200-char check when page count is unavailable.
|
||||
_MIN_CHARS_PER_PAGE = 50
|
||||
|
||||
|
||||
def _pymupdf_output_too_sparse(text: str, file_path: Path) -> bool:
|
||||
"""Return True if pymupdf4llm output is suspiciously short (image-based PDF).
|
||||
|
||||
Uses chars-per-page rather than an absolute threshold so that both short
|
||||
documents (few pages, few chars) and long documents (many pages, many chars)
|
||||
are handled correctly.
|
||||
"""
|
||||
chars = len(text.strip())
|
||||
doc = None
|
||||
pages: int | None = None
|
||||
try:
|
||||
import pymupdf
|
||||
|
||||
doc = pymupdf.open(str(file_path))
|
||||
pages = len(doc)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
if doc is not None:
|
||||
try:
|
||||
doc.close()
|
||||
except Exception:
|
||||
pass
|
||||
if pages is not None and pages > 0:
|
||||
return (chars / pages) < _MIN_CHARS_PER_PAGE
|
||||
# Fallback: absolute threshold when page count is unavailable
|
||||
return chars < 200
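Worked example of the threshold: a 10-page PDF yielding 300 characters averages 30 chars/page and is treated as image-based, while the same 300 characters from a 2-page PDF average 150 chars/page and pass (numbers are illustrative):

chars, pages = 300, 10
assert (chars / pages) < _MIN_CHARS_PER_PAGE   # 30 < 50, fall back to MarkItDown
chars, pages = 300, 2
assert (chars / pages) >= _MIN_CHARS_PER_PAGE  # 150 >= 50, keep pymupdf4llm output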
|
||||
|
||||
|
||||
def _convert_pdf_with_pymupdf4llm(file_path: Path) -> str | None:
|
||||
"""Attempt PDF conversion with pymupdf4llm.
|
||||
|
||||
Returns the markdown text, or None if pymupdf4llm is not installed or
|
||||
if conversion fails (e.g. encrypted/corrupt PDF).
|
||||
"""
|
||||
try:
|
||||
import pymupdf4llm
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
try:
|
||||
return pymupdf4llm.to_markdown(str(file_path))
|
||||
except Exception:
|
||||
logger.exception("pymupdf4llm failed to convert %s; falling back to MarkItDown", file_path.name)
|
||||
return None
|
||||
|
||||
|
||||
def _convert_with_markitdown(file_path: Path) -> str:
|
||||
"""Convert any supported file to markdown text using MarkItDown."""
|
||||
from markitdown import MarkItDown
|
||||
|
||||
md = MarkItDown()
|
||||
return md.convert(str(file_path)).text_content
|
||||
|
||||
|
||||
def _do_convert(file_path: Path, pdf_converter: str) -> str:
|
||||
"""Synchronous conversion — called directly or via asyncio.to_thread.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file.
|
||||
pdf_converter: "auto" | "pymupdf4llm" | "markitdown"
|
||||
"""
|
||||
is_pdf = file_path.suffix.lower() == ".pdf"
|
||||
|
||||
if is_pdf and pdf_converter != "markitdown":
|
||||
# Try pymupdf4llm first (auto or explicit)
|
||||
pymupdf_text = _convert_pdf_with_pymupdf4llm(file_path)
|
||||
|
||||
if pymupdf_text is not None:
|
||||
# pymupdf4llm is installed
|
||||
if pdf_converter == "pymupdf4llm":
|
||||
# Explicit — use as-is regardless of output length
|
||||
return pymupdf_text
|
||||
# auto mode: fall back if output looks like a failed parse.
|
||||
# Use chars-per-page to distinguish image-based PDFs (near 0) from
|
||||
# legitimately short documents.
|
||||
if not _pymupdf_output_too_sparse(pymupdf_text, file_path):
|
||||
return pymupdf_text
|
||||
logger.warning(
|
||||
"pymupdf4llm produced only %d chars for %s (likely image-based PDF); falling back to MarkItDown",
|
||||
len(pymupdf_text.strip()),
|
||||
file_path.name,
|
||||
)
|
||||
# pymupdf4llm not installed or fallback triggered → use MarkItDown
|
||||
|
||||
return _convert_with_markitdown(file_path)
|
||||
|
||||
|
||||
async def convert_file_to_markdown(file_path: Path) -> Path | None:
|
||||
"""Convert a file to markdown using markitdown.
|
||||
"""Convert a supported document file to Markdown.
|
||||
|
||||
PDF files are handled with a two-converter strategy (see module docstring).
|
||||
Large files (> 1 MB) are offloaded to a thread pool to avoid blocking the
|
||||
event loop.
|
||||
|
||||
Args:
|
||||
file_path: Path to the file to convert.
|
||||
|
||||
Returns:
|
||||
Path to the markdown file if conversion was successful, None otherwise.
|
||||
Path to the generated .md file, or None if conversion failed.
|
||||
"""
|
||||
try:
|
||||
from markitdown import MarkItDown
|
||||
pdf_converter = _get_pdf_converter()
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
md = MarkItDown()
|
||||
result = md.convert(str(file_path))
|
||||
if file_size > _ASYNC_THRESHOLD_BYTES:
|
||||
text = await asyncio.to_thread(_do_convert, file_path, pdf_converter)
|
||||
else:
|
||||
text = _do_convert(file_path, pdf_converter)
|
||||
|
||||
# Save as .md file with same name
|
||||
md_path = file_path.with_suffix(".md")
|
||||
md_path.write_text(result.text_content, encoding="utf-8")
|
||||
md_path.write_text(text, encoding="utf-8")
|
||||
|
||||
logger.info(f"Converted {file_path.name} to markdown: {md_path.name}")
|
||||
logger.info("Converted %s to markdown: %s (%d chars)", file_path.name, md_path.name, len(text))
|
||||
return md_path
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert {file_path.name} to markdown: {e}")
|
||||
logger.error("Failed to convert %s to markdown: %s", file_path.name, e)
|
||||
return None
|
||||
|
||||
|
||||
# Regex for bold-only lines that look like section headings.
|
||||
# Targets SEC filing structural headings that pymupdf4llm renders as **bold**
|
||||
rather than # Markdown headings (because they use the same font size as body text,
|
||||
# distinguished only by bold+caps formatting).
|
||||
#
|
||||
# Pattern requires ALL of:
|
||||
# 1. Entire line is a single **...** block (no surrounding prose)
|
||||
# 2. Starts with a recognised structural keyword:
|
||||
# - ITEM / PART / SECTION (with optional number/letter after)
|
||||
# - SCHEDULE, EXHIBIT, APPENDIX, ANNEX, CHAPTER
|
||||
# All-caps addresses, boilerplate ("CURRENT REPORT", "SIGNATURES",
|
||||
# "WASHINGTON, DC 20549") do NOT start with these keywords and are excluded.
|
||||
#
|
||||
# Chinese headings (第三节...) are already captured as standard # headings
|
||||
# by pymupdf4llm, so they don't need this pattern.
|
||||
_BOLD_HEADING_RE = re.compile(r"^\*\*((ITEM|PART|SECTION|SCHEDULE|EXHIBIT|APPENDIX|ANNEX|CHAPTER)\b[A-Z0-9 .,\-]*)\*\*\s*$")
|
||||
|
||||
# Regex for split-bold headings produced by pymupdf4llm when a heading spans
|
||||
# multiple text spans in the PDF (e.g. section number and title are separate spans).
|
||||
# Matches lines like: **1** **Introduction** or **3.2** **Multi-Head Attention**
|
||||
# Requirements:
|
||||
# 1. Entire line consists only of **...** blocks separated by whitespace (no prose)
|
||||
# 2. First block is a section number (digits and dots, e.g. "1", "3.2", "A.1")
|
||||
# 3. Second block must not be purely numeric/punctuation — excludes financial table
|
||||
# headers like **2023** **2022** **2021** while allowing non-ASCII titles such as
|
||||
# **1** **概述** or accented words (negative lookahead instead of [A-Za-z])
|
||||
# 4. At most two additional blocks (four total) with [^*]+ (no * inside) to keep
|
||||
# the regex linear and avoid ReDoS on attacker-controlled content
|
||||
_SPLIT_BOLD_HEADING_RE = re.compile(r"^\*\*[\dA-Z][\d\.]*\*\*\s+\*\*(?!\d[\d\s.,\-–—/:()%]*\*\*)[^*]+\*\*(?:\s+\*\*[^*]+\*\*){0,2}\s*$")
|
||||
|
||||
# Maximum number of outline entries injected into the agent context.
|
||||
# Keeps prompt size bounded even for very long documents.
|
||||
MAX_OUTLINE_ENTRIES = 50
|
||||
|
||||
_ALLOWED_PDF_CONVERTERS = {"auto", "pymupdf4llm", "markitdown"}
|
||||
|
||||
|
||||
def _clean_bold_title(raw: str) -> str:
|
||||
"""Normalise a title string that may contain pymupdf4llm bold artefacts.
|
||||
|
||||
pymupdf4llm sometimes emits adjacent bold spans as ``**A** **B**`` instead
|
||||
of a single ``**A B**`` block. This helper merges those fragments and then
|
||||
strips the outermost ``**...**`` wrapper so the caller gets plain text.
|
||||
|
||||
Examples::
|
||||
|
||||
"**Overview**" → "Overview"
|
||||
"**UNITED STATES** **SECURITIES**" → "UNITED STATES SECURITIES"
|
||||
"plain text" → "plain text" (unchanged)
|
||||
"""
|
||||
# Merge adjacent bold spans: "** **" → " "
|
||||
merged = re.sub(r"\*\*\s*\*\*", " ", raw).strip()
|
||||
# Strip outermost **...** if the whole string is wrapped
|
||||
if m := re.fullmatch(r"\*\*(.+?)\*\*", merged, re.DOTALL):
|
||||
return m.group(1).strip()
|
||||
return merged
|
||||
|
||||
|
||||
def extract_outline(md_path: Path) -> list[dict]:
|
||||
"""Extract document outline (headings) from a Markdown file.
|
||||
|
||||
Recognises three heading styles produced by pymupdf4llm:
|
||||
|
||||
1. Standard Markdown headings: lines starting with one or more '#'.
|
||||
Inline ``**...**`` wrappers and adjacent bold spans (``** **``) are
|
||||
cleaned so the title is plain text.
|
||||
|
||||
2. Bold-only structural headings: ``**ITEM 1. BUSINESS**``, ``**PART II**``,
|
||||
etc. SEC filings use bold+caps for section headings with the same font
|
||||
size as body text, so pymupdf4llm cannot promote them to # headings.
|
||||
|
||||
3. Split-bold headings: ``**1** **Introduction**``, ``**3.2** **Attention**``.
|
||||
pymupdf4llm emits these when the section number and title text are
|
||||
separate spans in the underlying PDF (common in academic papers).
|
||||
|
||||
Args:
|
||||
md_path: Path to the .md file.
|
||||
|
||||
Returns:
|
||||
List of dicts with keys: title (str), line (int, 1-based).
|
||||
When the outline is truncated at MAX_OUTLINE_ENTRIES, a sentinel entry
|
||||
``{"truncated": True}`` is appended as the last element so callers can
|
||||
render a "showing first N headings" hint without re-scanning the file.
|
||||
Returns an empty list if the file cannot be read or has no headings.
|
||||
"""
|
||||
outline: list[dict] = []
|
||||
try:
|
||||
with md_path.open(encoding="utf-8") as f:
|
||||
for lineno, line in enumerate(f, 1):
|
||||
stripped = line.strip()
|
||||
if not stripped:
|
||||
continue
|
||||
|
||||
# Style 1: standard Markdown heading
|
||||
if stripped.startswith("#"):
|
||||
title = _clean_bold_title(stripped.lstrip("#").strip())
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
# Style 2: single bold block with SEC structural keyword
|
||||
elif m := _BOLD_HEADING_RE.match(stripped):
|
||||
title = m.group(1).strip()
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
# Style 3: split-bold heading — **<num>** **<title>**
|
||||
# Regex already enforces max 4 blocks and non-numeric second block.
|
||||
elif _SPLIT_BOLD_HEADING_RE.match(stripped):
|
||||
title = " ".join(re.findall(r"\*\*([^*]+)\*\*", stripped))
|
||||
if title:
|
||||
outline.append({"title": title, "line": lineno})
|
||||
|
||||
if len(outline) >= MAX_OUTLINE_ENTRIES:
|
||||
outline.append({"truncated": True})
|
||||
break
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
return outline
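A quick self-contained sketch exercising all three heading styles (the sample document is made up):

import tempfile
from pathlib import Path

sample = "\n".join([
    "# **Overview**",                    # style 1: markdown heading with bold wrapper
    "body text",
    "**ITEM 1. BUSINESS**",              # style 2: SEC-style bold structural heading
    "**3.2** **Multi-Head Attention**",  # style 3: split-bold number + title
])
with tempfile.TemporaryDirectory() as tmp:
    md = Path(tmp) / "doc.md"
    md.write_text(sample, encoding="utf-8")
    print(extract_outline(md))
    # Expected: [{'title': 'Overview', 'line': 1},
    #            {'title': 'ITEM 1. BUSINESS', 'line': 3},
    #            {'title': '3.2 Multi-Head Attention', 'line': 4}]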
|
||||
|
||||
|
||||
def _get_pdf_converter() -> str:
|
||||
"""Read pdf_converter setting from app config, defaulting to 'auto'.
|
||||
|
||||
Normalizes the value to lowercase and validates it against the allowed set
|
||||
so that values like 'AUTO' or 'MarkItDown' from config.yaml don't silently
|
||||
fall through to unexpected behaviour.
|
||||
"""
|
||||
try:
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
cfg = get_app_config()
|
||||
uploads_cfg = getattr(cfg, "uploads", None)
|
||||
if uploads_cfg is not None:
|
||||
raw = str(getattr(uploads_cfg, "pdf_converter", "auto")).strip().lower()
|
||||
if raw not in _ALLOWED_PDF_CONVERTERS:
|
||||
logger.warning("Invalid pdf_converter value %r; falling back to 'auto'", raw)
|
||||
return "auto"
|
||||
return raw
|
||||
except Exception:
|
||||
pass
|
||||
return "auto"
|
||||
|
||||
@ -9,16 +9,17 @@ dependencies = [
|
||||
"dotenv>=0.9.9",
|
||||
"httpx>=0.28.0",
|
||||
"kubernetes>=30.0.0",
|
||||
"langchain>=1.2.3",
|
||||
"langchain>=1.2.3,<1.2.10",
|
||||
"langchain-anthropic>=1.3.4",
|
||||
"langchain-deepseek>=1.0.1",
|
||||
"langchain-mcp-adapters>=0.1.0",
|
||||
"langchain-openai>=1.1.7",
|
||||
"langfuse>=3.4.1",
|
||||
"langgraph>=1.0.6,<1.0.10",
|
||||
"langgraph-prebuilt>=1.0.6,<1.0.9",
|
||||
"langgraph-api>=0.7.0,<0.8.0",
|
||||
"langgraph-cli>=0.4.14",
|
||||
"langgraph-runtime-inmem>=0.22.1",
|
||||
"langgraph-runtime-inmem>=0.22.1,<0.27.0",
|
||||
"markdownify>=1.2.2",
|
||||
"markitdown[all,xlsx]>=0.0.1a2",
|
||||
"pydantic>=2.12.5",
|
||||
@ -34,6 +35,9 @@ dependencies = [
|
||||
"langgraph-sdk>=0.1.51",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
pymupdf = ["pymupdf4llm>=0.0.17"]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
@ -16,6 +16,10 @@ dependencies = [
|
||||
"python-telegram-bot>=21.0",
|
||||
"langgraph-sdk>=0.1.51",
|
||||
"markdown-to-mrkdwn>=0.3.1",
|
||||
"wecom-aibot-python-sdk>=0.1.6",
|
||||
"bcrypt>=4.0.0",
|
||||
"pyjwt>=2.9.0",
|
||||
"email-validator>=2.0.0",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
|
||||
506
backend/tests/test_auth.py
Normal file
@ -0,0 +1,506 @@
|
||||
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import (
|
||||
AuthContext,
|
||||
Permissions,
|
||||
get_auth_context,
|
||||
require_auth,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# ── Password Hashing ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hash_password_and_verify():
|
||||
"""Hashing and verification round-trip."""
|
||||
password = "s3cr3tP@ssw0rd!"
|
||||
hashed = hash_password(password)
|
||||
assert hashed != password
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("wrongpassword", hashed) is False
|
||||
|
||||
|
||||
def test_hash_password_different_each_time():
|
||||
"""bcrypt generates unique salts, so same password has different hashes."""
|
||||
password = "testpassword"
|
||||
h1 = hash_password(password)
|
||||
h2 = hash_password(password)
|
||||
assert h1 != h2 # Different salts
|
||||
# But both verify correctly
|
||||
assert verify_password(password, h1) is True
|
||||
assert verify_password(password, h2) is True
|
||||
|
||||
|
||||
def test_verify_password_rejects_empty():
|
||||
"""Empty password should not verify."""
|
||||
hashed = hash_password("nonempty")
|
||||
assert verify_password("", hashed) is False
|
||||
|
||||
|
||||
# ── JWT ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_create_and_decode_token():
|
||||
"""JWT creation and decoding round-trip."""
|
||||
user_id = str(uuid4())
|
||||
# Set a valid JWT secret for this test
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(user_id)
|
||||
assert isinstance(token, str)
|
||||
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
def test_decode_token_expired():
|
||||
"""Expired token returns TokenError.EXPIRED."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
user_id = str(uuid4())
|
||||
# Create token that expires immediately
|
||||
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
|
||||
payload = decode_token(token)
|
||||
assert payload == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid():
|
||||
"""Invalid token returns TokenError."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
assert isinstance(decode_token("not.a.valid.token"), TokenError)
|
||||
assert isinstance(decode_token(""), TokenError)
|
||||
assert isinstance(decode_token("completely-wrong"), TokenError)
|
||||
|
||||
|
||||
def test_create_token_custom_expiry():
|
||||
"""Custom expiry is respected."""
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
# ── AuthContext ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_context_unauthenticated():
|
||||
"""AuthContext with no user."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
assert ctx.is_authenticated is False
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_authenticated_no_perms():
|
||||
"""AuthContext with user but no permissions."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
assert ctx.is_authenticated is True
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_has_permission():
|
||||
"""AuthContext permission checking."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
|
||||
ctx = AuthContext(user=user, permissions=perms)
|
||||
assert ctx.has_permission("threads", "read") is True
|
||||
assert ctx.has_permission("threads", "write") is True
|
||||
assert ctx.has_permission("threads", "delete") is False
|
||||
assert ctx.has_permission("runs", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_require_user_raises():
|
||||
"""require_user raises 401 when not authenticated."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
ctx.require_user()
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_auth_context_require_user_returns_user():
|
||||
"""require_user returns user when authenticated."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
returned = ctx.require_user()
|
||||
assert returned == user
|
||||
|
||||
|
||||
# ── get_auth_context helper ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_auth_context_not_set():
|
||||
"""get_auth_context returns None when auth not set on request."""
|
||||
mock_request = MagicMock()
|
||||
# Make getattr return None (simulating attribute not set)
|
||||
mock_request.state = MagicMock()
|
||||
del mock_request.state.auth
|
||||
assert get_auth_context(mock_request) is None
|
||||
|
||||
|
||||
def test_get_auth_context_set():
|
||||
"""get_auth_context returns the AuthContext from request."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.auth = ctx
|
||||
|
||||
assert get_auth_context(mock_request) == ctx
|
||||
|
||||
|
||||
# ── require_auth decorator ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_auth_sets_auth_context():
|
||||
"""require_auth sets auth context on request from cookie."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_auth
|
||||
async def endpoint(request: Request):
|
||||
ctx = get_auth_context(request)
|
||||
return {"authenticated": ctx.is_authenticated}
|
||||
|
||||
with TestClient(app) as client:
|
||||
# No cookie → anonymous
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["authenticated"] is False
|
||||
|
||||
|
||||
def test_require_auth_requires_request_param():
|
||||
"""require_auth raises ValueError if request parameter is missing."""
|
||||
import asyncio
|
||||
|
||||
@require_auth
|
||||
async def bad_endpoint(): # Missing `request` parameter
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
|
||||
asyncio.run(bad_endpoint())
|
||||
|
||||
|
||||
# ── require_permission decorator ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_permission_requires_auth():
|
||||
"""require_permission raises 401 when not authenticated."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "read")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_require_permission_denies_wrong_permission():
|
||||
"""User without required permission gets 403."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "delete")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 403
|
||||
assert "Permission denied" in response.json()["detail"]
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── User Model Fields ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_model_has_needs_setup_default_false():
|
||||
"""New users default to needs_setup=False."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.needs_setup is False
|
||||
|
||||
|
||||
def test_user_model_has_token_version_default_zero():
|
||||
"""New users default to token_version=0."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.token_version == 0
|
||||
|
||||
|
||||
def test_user_model_needs_setup_true():
|
||||
"""Auto-created admin has needs_setup=True."""
|
||||
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
|
||||
assert user.needs_setup is True
|
||||
|
||||
|
||||
def test_sqlite_round_trip_new_fields():
|
||||
"""needs_setup and token_version survive create → read round-trip."""
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.gateway.auth.repositories import sqlite as sqlite_mod
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_users.db")
|
||||
old_path = sqlite_mod._resolved_db_path
|
||||
old_init = sqlite_mod._table_initialized
|
||||
sqlite_mod._resolved_db_path = Path(db_path)
|
||||
sqlite_mod._table_initialized = False
|
||||
try:
|
||||
repo = sqlite_mod.SQLiteUserRepository()
|
||||
user = User(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
created = asyncio.run(repo.create_user(user))
|
||||
assert created.needs_setup is True
|
||||
assert created.token_version == 3
|
||||
|
||||
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
|
||||
assert fetched is not None
|
||||
assert fetched.needs_setup is True
|
||||
assert fetched.token_version == 3
|
||||
|
||||
fetched.needs_setup = False
|
||||
fetched.token_version = 4
|
||||
asyncio.run(repo.update_user(fetched))
|
||||
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
|
||||
assert refetched.needs_setup is False
|
||||
assert refetched.token_version == 4
|
||||
finally:
|
||||
sqlite_mod._resolved_db_path = old_path
|
||||
sqlite_mod._table_initialized = old_init
|
||||
|
||||
|
||||
# ── Token Versioning ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_jwt_encodes_ver():
|
||||
"""JWT payload includes ver field."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()), token_version=3)
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 3
|
||||
|
||||
|
||||
def test_jwt_default_ver_zero():
|
||||
"""JWT ver defaults to 0."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()))
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 0
|
||||
|
||||
|
||||
def test_token_version_mismatch_rejects():
|
||||
"""Token with stale ver is rejected by get_current_user_from_request."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, token_version=0)
|
||||
|
||||
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {"access_token": token}
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_user = AsyncMock(return_value=mock_user)
|
||||
mock_provider_fn.return_value = mock_provider
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
# ── change-password extension ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_change_password_request_accepts_new_email():
|
||||
"""ChangePasswordRequest model accepts optional new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(
|
||||
current_password="old",
|
||||
new_password="newpassword",
|
||||
new_email="new@example.com",
|
||||
)
|
||||
assert req.new_email == "new@example.com"
|
||||
|
||||
|
||||
def test_change_password_request_new_email_optional():
|
||||
"""ChangePasswordRequest model works without new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
|
||||
assert req.new_email is None
|
||||
|
||||
|
||||
def test_login_response_includes_needs_setup():
|
||||
"""LoginResponse includes needs_setup field."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=3600, needs_setup=True)
|
||||
assert resp.needs_setup is True
|
||||
resp2 = LoginResponse(expires_in=3600)
|
||||
assert resp2.needs_setup is False
|
||||
|
||||
|
||||
# ── Rate Limiting ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_rate_limiter_allows_under_limit():
|
||||
"""Requests under the limit are allowed."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
|
||||
|
||||
_login_attempts.clear()
|
||||
_check_rate_limit("192.168.1.1") # Should not raise
|
||||
|
||||
|
||||
def test_rate_limiter_blocks_after_max_failures():
|
||||
"""IP is blocked after 5 consecutive failures."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.1"
|
||||
for _ in range(5):
|
||||
_record_login_failure(ip)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
_check_rate_limit(ip)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
def test_rate_limiter_resets_on_success():
|
||||
"""Successful login clears the failure counter."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.2"
|
||||
for _ in range(4):
|
||||
_record_login_failure(ip)
|
||||
_record_login_success(ip)
|
||||
_check_rate_limit(ip) # Should not raise
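The tested contract can be summarised in a minimal sketch; the shipped limiter also tracks timestamps for the 5-minute lockout window and bounds dict growth, which this illustration omits:

from fastapi import HTTPException

_MAX_FAILURES = 5
_attempts: dict[str, int] = {}

def record_failure(ip: str) -> None:
    _attempts[ip] = _attempts.get(ip, 0) + 1

def record_success(ip: str) -> None:
    _attempts.pop(ip, None)  # a successful login clears the counter

def check_rate_limit(ip: str) -> None:
    if _attempts.get(ip, 0) >= _MAX_FAILURES:
        raise HTTPException(status_code=429, detail="Too many login attempts")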
|
||||
|
||||
|
||||
# ── Client IP extraction ─────────────────────────────────────────────────


def test_get_client_ip_direct_connection():
    """Without nginx (no X-Real-IP), falls back to request.client.host."""
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "203.0.113.42"
    req.headers = {}
    assert _get_client_ip(req) == "203.0.113.42"


def test_get_client_ip_uses_x_real_ip():
    """X-Real-IP (set by nginx) is used when present."""
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "10.0.0.1"  # uvicorn may have replaced this with XFF[0]
    req.headers = {"x-real-ip": "203.0.113.42"}
    assert _get_client_ip(req) == "203.0.113.42"


def test_get_client_ip_xff_ignored():
    """X-Forwarded-For is never used; only X-Real-IP matters."""
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "10.0.0.1"
    req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
    assert _get_client_ip(req) == "198.51.100.5"


def test_get_client_ip_no_real_ip_fallback():
    """No X-Real-IP → falls back to client.host (direct connection)."""
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "127.0.0.1"
    req.headers = {}
    assert _get_client_ip(req) == "127.0.0.1"


def test_get_client_ip_x_real_ip_always_preferred():
    """X-Real-IP is always preferred over client.host regardless of IP."""
    from app.gateway.routers.auth import _get_client_ip

    req = MagicMock()
    req.client.host = "203.0.113.99"
    req.headers = {"x-real-ip": "198.51.100.7"}
    assert _get_client_ip(req) == "198.51.100.7"


# ── Weak JWT secret warning ──────────────────────────────────────────────────


def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
    """get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
    import logging

    import app.gateway.auth.config as config_module

    config_module._auth_config = None
    monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)

    with caplog.at_level(logging.WARNING):
        config = config_module.get_auth_config()

    assert config.jwt_secret  # non-empty ephemeral secret
    assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)

    # Cleanup
    config_module._auth_config = None
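These tests pin down the proxy-trust policy behind the X-Forwarded-For bypass fix from the commit message: only nginx's X-Real-IP is honored, since XFF is client-controllable. A sketch of the helper under that assumption (the "unknown" fallback for a missing client is assumed, not taken from the diff):

def _get_client_ip(request) -> str:
    # Trust only X-Real-IP (set by our own nginx); X-Forwarded-For is spoofable.
    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        return real_ip
    # Direct connection: fall back to the socket peer address.
    return request.client.host if request.client else "unknown"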
backend/tests/test_auth_config.py (new file, 54 lines)
@@ -0,0 +1,54 @@
"""Tests for AuthConfig typed configuration."""

import os
from unittest.mock import patch

import pytest

from app.gateway.auth.config import AuthConfig


def test_auth_config_defaults():
    config = AuthConfig(jwt_secret="test-secret-key-123")
    assert config.token_expiry_days == 7


def test_auth_config_token_expiry_range():
    AuthConfig(jwt_secret="s", token_expiry_days=1)
    AuthConfig(jwt_secret="s", token_expiry_days=30)
    with pytest.raises(Exception):
        AuthConfig(jwt_secret="s", token_expiry_days=0)
    with pytest.raises(Exception):
        AuthConfig(jwt_secret="s", token_expiry_days=31)


def test_auth_config_from_env():
    env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
    with patch.dict(os.environ, env, clear=False):
        import app.gateway.auth.config as cfg

        old = cfg._auth_config
        cfg._auth_config = None
        try:
            config = cfg.get_auth_config()
            assert config.jwt_secret == "test-jwt-secret-from-env"
        finally:
            cfg._auth_config = old


def test_auth_config_missing_secret_generates_ephemeral(caplog):
    import logging

    import app.gateway.auth.config as cfg

    old = cfg._auth_config
    cfg._auth_config = None
    try:
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("AUTH_JWT_SECRET", None)
            with caplog.at_level(logging.WARNING):
                config = cfg.get_auth_config()
            assert config.jwt_secret
            assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
    finally:
        cfg._auth_config = old
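For orientation, a sketch of the config module these tests poke at: a module-level `_auth_config` cache plus a `get_auth_config()` that falls back to an ephemeral secret with a warning. Field names and bounds match the tests; the exact log wording and pydantic details are assumptions.

import logging
import os
import secrets

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class AuthConfig(BaseModel):
    jwt_secret: str  # required, no default
    token_expiry_days: int = Field(default=7, ge=1, le=30)


_auth_config: AuthConfig | None = None


def set_auth_config(config: AuthConfig) -> None:
    global _auth_config
    _auth_config = config


def get_auth_config() -> AuthConfig:
    global _auth_config
    if _auth_config is None:
        secret = os.environ.get("AUTH_JWT_SECRET")
        if not secret:
            secret = secrets.token_urlsafe(32)  # ephemeral: sessions die on restart
            logger.warning("AUTH_JWT_SECRET is not set; generated an ephemeral secret")
        _auth_config = AuthConfig(jwt_secret=secret)
    return _auth_config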
backend/tests/test_auth_errors.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""Tests for auth error types and typed decode_token."""

from datetime import UTC, datetime, timedelta

import jwt as pyjwt

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import create_access_token, decode_token


def test_auth_error_code_values():
    assert AuthErrorCode.INVALID_CREDENTIALS == "invalid_credentials"
    assert AuthErrorCode.TOKEN_EXPIRED == "token_expired"
    assert AuthErrorCode.NOT_AUTHENTICATED == "not_authenticated"


def test_token_error_values():
    assert TokenError.EXPIRED == "expired"
    assert TokenError.INVALID_SIGNATURE == "invalid_signature"
    assert TokenError.MALFORMED == "malformed"


def test_auth_error_response_serialization():
    err = AuthErrorResponse(
        code=AuthErrorCode.TOKEN_EXPIRED,
        message="Token has expired",
    )
    d = err.model_dump()
    assert d == {"code": "token_expired", "message": "Token has expired"}


def test_auth_error_response_from_dict():
    d = {"code": "invalid_credentials", "message": "Wrong password"}
    err = AuthErrorResponse(**d)
    assert err.code == AuthErrorCode.INVALID_CREDENTIALS


# ── decode_token typed failure tests ──────────────────────────────

_TEST_SECRET = "test-secret-for-jwt-decode-token-tests"


def _setup_config():
    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))


def test_decode_token_returns_token_error_on_expired():
    _setup_config()
    expired_payload = {"sub": "user-1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired_payload, _TEST_SECRET, algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.EXPIRED


def test_decode_token_returns_token_error_on_bad_signature():
    _setup_config()
    payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.INVALID_SIGNATURE


def test_decode_token_returns_token_error_on_malformed():
    _setup_config()
    result = decode_token("not-a-jwt")
    assert result == TokenError.MALFORMED


def test_decode_token_returns_payload_on_valid():
    _setup_config()
    token = create_access_token("user-123")
    result = decode_token(token)
    assert not isinstance(result, TokenError)
    assert result.sub == "user-123"
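`decode_token` is notable for returning a `TokenError` member instead of raising. A sketch matching these tests, assuming PyJWT's standard exception hierarchy (`TokenPayload` is a stand-in for whatever the real module defines; only `sub` is exercised here):

import jwt as pyjwt
from pydantic import BaseModel


class TokenPayload(BaseModel):
    sub: str  # other claims (exp, iat, ver) elided


def decode_token(token: str) -> TokenPayload | TokenError:
    config = get_auth_config()
    try:
        data = pyjwt.decode(token, config.jwt_secret, algorithms=["HS256"])
    except pyjwt.ExpiredSignatureError:
        return TokenError.EXPIRED
    except pyjwt.InvalidSignatureError:
        return TokenError.INVALID_SIGNATURE
    except pyjwt.InvalidTokenError:  # malformed, missing segments, bad header...
        return TokenError.MALFORMED
    return TokenPayload(**data)

Callers then map EXPIRED to `token_expired` and everything else to `token_invalid`, as the type-system tests further down verify.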
backend/tests/test_auth_middleware.py (new file, 216 lines)
@@ -0,0 +1,216 @@
"""Tests for the global AuthMiddleware (fail-closed safety net)."""

import pytest
from starlette.testclient import TestClient

from app.gateway.auth_middleware import AuthMiddleware, _is_public

# ── _is_public unit tests ─────────────────────────────────────────────────


@pytest.mark.parametrize(
    "path",
    [
        "/health",
        "/health/",
        "/docs",
        "/docs/",
        "/redoc",
        "/openapi.json",
        "/api/v1/auth/login/local",
        "/api/v1/auth/register",
        "/api/v1/auth/logout",
        "/api/v1/auth/setup-status",
    ],
)
def test_public_paths(path: str):
    assert _is_public(path) is True


@pytest.mark.parametrize(
    "path",
    [
        "/api/models",
        "/api/mcp/config",
        "/api/memory",
        "/api/skills",
        "/api/threads/123",
        "/api/threads/123/uploads",
        "/api/agents",
        "/api/channels",
        "/api/runs/stream",
        "/api/threads/123/runs",
        "/api/v1/auth/me",
        "/api/v1/auth/change-password",
    ],
)
def test_protected_paths(path: str):
    assert _is_public(path) is False


# ── Trailing slash / normalization edge cases ─────────────────────────────


@pytest.mark.parametrize(
    "path",
    [
        "/api/v1/auth/login/local/",
        "/api/v1/auth/register/",
        "/api/v1/auth/logout/",
        "/api/v1/auth/setup-status/",
    ],
)
def test_public_auth_paths_with_trailing_slash(path: str):
    assert _is_public(path) is True


@pytest.mark.parametrize(
    "path",
    [
        "/api/models/",
        "/api/v1/auth/me/",
        "/api/v1/auth/change-password/",
    ],
)
def test_protected_paths_with_trailing_slash(path: str):
    assert _is_public(path) is False


def test_unknown_api_path_is_protected():
    """Fail-closed: any new /api/* path is protected by default."""
    assert _is_public("/api/new-feature") is False
    assert _is_public("/api/v2/something") is False
    assert _is_public("/api/v1/auth/new-endpoint") is False
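A plausible `_is_public` consistent with the matrix above: a small allowlist with trailing-slash normalization, so anything not explicitly listed stays protected (fail-closed). The real set lives in app/gateway/auth_middleware.py; this sketch only mirrors what the tests assert.

_PUBLIC_PATHS = {
    "/health", "/docs", "/redoc", "/openapi.json",
    "/api/v1/auth/login/local", "/api/v1/auth/register",
    "/api/v1/auth/logout", "/api/v1/auth/setup-status",
}


def _is_public(path: str) -> bool:
    return path.rstrip("/") in _PUBLIC_PATHS  # "/health/" → "/health"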
# ── Middleware integration tests ──────────────────────────────────────────


def _make_app():
    """Create a minimal FastAPI app with AuthMiddleware for testing."""
    from fastapi import FastAPI

    app = FastAPI()
    app.add_middleware(AuthMiddleware)

    @app.get("/health")
    async def health():
        return {"status": "ok"}

    @app.get("/api/v1/auth/me")
    async def auth_me():
        return {"id": "1", "email": "test@test.com"}

    @app.get("/api/v1/auth/setup-status")
    async def setup_status():
        return {"needs_setup": False}

    @app.get("/api/models")
    async def models_get():
        return {"models": []}

    @app.put("/api/mcp/config")
    async def mcp_put():
        return {"ok": True}

    @app.delete("/api/threads/abc")
    async def thread_delete():
        return {"ok": True}

    @app.patch("/api/threads/abc")
    async def thread_patch():
        return {"ok": True}

    @app.post("/api/threads/abc/runs/stream")
    async def stream():
        return {"ok": True}

    @app.get("/api/future-endpoint")
    async def future():
        return {"ok": True}

    return app


@pytest.fixture
def client():
    return TestClient(_make_app())


def test_public_path_no_cookie(client):
    res = client.get("/health")
    assert res.status_code == 200


def test_public_auth_path_no_cookie(client):
    """Public auth endpoints (login/register) pass without cookie."""
    res = client.get("/api/v1/auth/setup-status")
    assert res.status_code == 200


def test_protected_auth_path_no_cookie(client):
    """/auth/me requires cookie even though it's under /api/v1/auth/."""
    res = client.get("/api/v1/auth/me")
    assert res.status_code == 401


def test_protected_path_no_cookie_returns_401(client):
    res = client.get("/api/models")
    assert res.status_code == 401
    body = res.json()
    assert body["detail"]["code"] == "not_authenticated"


def test_protected_path_with_cookie_passes(client):
    res = client.get("/api/models", cookies={"access_token": "some-token"})
    assert res.status_code == 200


def test_protected_post_no_cookie_returns_401(client):
    res = client.post("/api/threads/abc/runs/stream")
    assert res.status_code == 401


# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────


def test_protected_put_no_cookie(client):
    res = client.put("/api/mcp/config")
    assert res.status_code == 401


def test_protected_delete_no_cookie(client):
    res = client.delete("/api/threads/abc")
    assert res.status_code == 401


def test_protected_patch_no_cookie(client):
    res = client.patch("/api/threads/abc")
    assert res.status_code == 401


def test_put_with_cookie_passes(client):
    client.cookies.set("access_token", "tok")
    res = client.put("/api/mcp/config")
    assert res.status_code == 200


def test_delete_with_cookie_passes(client):
    client.cookies.set("access_token", "tok")
    res = client.delete("/api/threads/abc")
    assert res.status_code == 200


# ── Fail-closed: unknown future endpoints ─────────────────────────────────


def test_unknown_endpoint_no_cookie_returns_401(client):
    """Any new /api/* endpoint is blocked by default without cookie."""
    res = client.get("/api/future-endpoint")
    assert res.status_code == 401


def test_unknown_endpoint_with_cookie_passes(client):
    client.cookies.set("access_token", "tok")
    res = client.get("/api/future-endpoint")
    assert res.status_code == 200
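Note that the integration tests only ever check for cookie presence (a bare "tok" passes), which suggests the middleware is a coarse fail-closed gate and full JWT validation happens later in the route dependencies. A sketch under that assumption:

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse


class AuthMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        path = request.url.path
        if path.startswith("/api") and not _is_public(path):
            # Presence check only; signature/expiry are verified in the deps.
            if "access_token" not in request.cookies:
                return JSONResponse(
                    status_code=401,
                    content={"detail": {"code": "not_authenticated", "message": "Not authenticated"}},
                )
        return await call_next(request)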
backend/tests/test_auth_type_system.py (new file, 675 lines)
@@ -0,0 +1,675 @@
"""Tests for auth type system hardening.

Covers structured error responses, typed decode_token callers,
CSRF middleware path matching, config-driven cookie security,
and unhappy paths / edge cases for all auth boundaries.
"""

import os
import secrets
from datetime import UTC, datetime, timedelta
from unittest.mock import patch

import jwt as pyjwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.csrf_middleware import (
    CSRF_COOKIE_NAME,
    CSRF_HEADER_NAME,
    CSRFMiddleware,
    is_auth_endpoint,
    should_check_csrf,
)

# ── Setup ────────────────────────────────────────────────────────────

_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"


def _setup_config():
    set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))


# ── CSRF Middleware Path Matching ────────────────────────────────────


class _FakeRequest:
    """Minimal request mock for CSRF path matching tests."""

    def __init__(self, path: str, method: str = "POST"):
        self.method = method

        class _URL:
            def __init__(self, p):
                self.path = p

        self.url = _URL(path)
        self.cookies = {}
        self.headers = {}


def test_csrf_exempts_login_local():
    """login/local (actual route) should be exempt from CSRF."""
    req = _FakeRequest("/api/v1/auth/login/local")
    assert is_auth_endpoint(req) is True


def test_csrf_exempts_login_local_trailing_slash():
    """Trailing slash should also be exempt."""
    req = _FakeRequest("/api/v1/auth/login/local/")
    assert is_auth_endpoint(req) is True


def test_csrf_exempts_logout():
    req = _FakeRequest("/api/v1/auth/logout")
    assert is_auth_endpoint(req) is True


def test_csrf_exempts_register():
    req = _FakeRequest("/api/v1/auth/register")
    assert is_auth_endpoint(req) is True


def test_csrf_does_not_exempt_old_login_path():
    """Old /api/v1/auth/login (without /local) should NOT be exempt."""
    req = _FakeRequest("/api/v1/auth/login")
    assert is_auth_endpoint(req) is False


def test_csrf_does_not_exempt_me():
    req = _FakeRequest("/api/v1/auth/me")
    assert is_auth_endpoint(req) is False


def test_csrf_skips_get_requests():
    req = _FakeRequest("/api/v1/auth/me", method="GET")
    assert should_check_csrf(req) is False


def test_csrf_checks_post_to_protected():
    req = _FakeRequest("/api/v1/some/endpoint", method="POST")
    assert should_check_csrf(req) is True
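These predicates encode a double-submit-cookie scheme with a short exemption list (no CSRF token exists yet at login/register). A sketch matching the assertions; the real exempt set may be longer (setup-status, for instance, is not exercised here):

_CSRF_EXEMPT = {"/api/v1/auth/login/local", "/api/v1/auth/register", "/api/v1/auth/logout"}
_SAFE_METHODS = {"GET", "HEAD", "OPTIONS"}


def is_auth_endpoint(request) -> bool:
    return request.url.path.rstrip("/") in _CSRF_EXEMPT


def should_check_csrf(request) -> bool:
    return request.method not in _SAFE_METHODS and not is_auth_endpoint(request)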
# ── Structured Error Response Format ────────────────────────────────


def test_auth_error_response_has_code_and_message():
    """All auth errors should have structured {code, message} format."""
    err = AuthErrorResponse(
        code=AuthErrorCode.INVALID_CREDENTIALS,
        message="Wrong password",
    )
    d = err.model_dump()
    assert "code" in d
    assert "message" in d
    assert d["code"] == "invalid_credentials"


def test_auth_error_response_all_codes_serializable():
    """Every AuthErrorCode should be serializable in AuthErrorResponse."""
    for code in AuthErrorCode:
        err = AuthErrorResponse(code=code, message=f"Test {code.value}")
        d = err.model_dump()
        assert d["code"] == code.value


# ── decode_token Caller Pattern ──────────────────────────────────────


def test_decode_token_expired_maps_to_token_expired_code():
    """TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
    _setup_config()
    from datetime import UTC, datetime, timedelta

    import jwt as pyjwt

    expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.EXPIRED

    # Verify the mapping pattern used in route handlers
    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_EXPIRED


def test_decode_token_invalid_sig_maps_to_token_invalid_code():
    """TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
    _setup_config()
    from datetime import UTC, datetime, timedelta

    import jwt as pyjwt

    payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
    result = decode_token(token)
    assert result == TokenError.INVALID_SIGNATURE

    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_INVALID


def test_decode_token_malformed_maps_to_token_invalid_code():
    """TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
    _setup_config()
    result = decode_token("garbage")
    assert result == TokenError.MALFORMED

    code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
    assert code == AuthErrorCode.TOKEN_INVALID


# ── Login Response Format ────────────────────────────────────────────


def test_login_response_model_has_no_access_token():
    """LoginResponse should NOT contain access_token field (RFC-001)."""
    from app.gateway.routers.auth import LoginResponse

    resp = LoginResponse(expires_in=604800)
    d = resp.model_dump()
    assert "access_token" not in d
    assert "expires_in" in d
    assert d["expires_in"] == 604800


def test_login_response_model_fields():
    """LoginResponse has expires_in and needs_setup."""
    from app.gateway.routers.auth import LoginResponse

    fields = set(LoginResponse.model_fields.keys())
    assert fields == {"expires_in", "needs_setup"}


# ── AuthConfig in Route ──────────────────────────────────────────────


def test_auth_config_token_expiry_used_in_login_response():
    """LoginResponse.expires_in should come from config.token_expiry_days."""
    from app.gateway.routers.auth import LoginResponse

    expected_seconds = 14 * 24 * 3600
    resp = LoginResponse(expires_in=expected_seconds)
    assert resp.expires_in == expected_seconds
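The response model implied by the two field tests above is tiny; a sketch (per RFC-001 the token travels only in the HttpOnly cookie, never in the body):

from pydantic import BaseModel


class LoginResponse(BaseModel):
    expires_in: int            # seconds, derived from config.token_expiry_days
    needs_setup: bool = False  # True only for the auto-created first-boot admin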
# ── UserResponse Type Preservation ───────────────────────────────────


def test_user_response_system_role_literal():
    """UserResponse.system_role should only accept 'admin' or 'user'."""
    from app.gateway.auth.models import UserResponse

    # Valid roles
    resp = UserResponse(id="1", email="a@b.com", system_role="admin")
    assert resp.system_role == "admin"

    resp = UserResponse(id="1", email="a@b.com", system_role="user")
    assert resp.system_role == "user"


def test_user_response_rejects_invalid_role():
    """UserResponse should reject invalid system_role values."""
    from app.gateway.auth.models import UserResponse

    with pytest.raises(ValidationError):
        UserResponse(id="1", email="a@b.com", system_role="superadmin")


# ══════════════════════════════════════════════════════════════════════
# UNHAPPY PATHS / EDGE CASES
# ══════════════════════════════════════════════════════════════════════


# ── get_current_user structured 401 responses ────────────────────────


def test_get_current_user_no_cookie_returns_not_authenticated():
    """No cookie → 401 with code=not_authenticated."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    mock_request = type("MockRequest", (), {"cookies": {}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "not_authenticated"


def test_get_current_user_expired_token_returns_token_expired():
    """Expired token → 401 with code=token_expired."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")

    mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "token_expired"


def test_get_current_user_invalid_token_returns_token_invalid():
    """Bad signature → 401 with code=token_invalid."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")

    mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "token_invalid"


def test_get_current_user_malformed_token_returns_token_invalid():
    """Garbage token → 401 with code=token_invalid."""
    import asyncio

    from fastapi import HTTPException

    from app.gateway.deps import get_current_user_from_request

    _setup_config()
    mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
    with pytest.raises(HTTPException) as exc_info:
        asyncio.run(get_current_user_from_request(mock_request))
    assert exc_info.value.status_code == 401
    detail = exc_info.value.detail
    assert detail["code"] == "token_invalid"
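The four tests above fix the dependency's error mapping. A sketch of `get_current_user_from_request` consistent with them (`_load_user` is a hypothetical stand-in for the real lookup, which per the commit message also rejects tokens whose `ver` lags the user's token_version):

from fastapi import HTTPException


async def get_current_user_from_request(request):
    token = request.cookies.get("access_token")
    if not token:
        raise HTTPException(status_code=401, detail={"code": "not_authenticated", "message": "Not authenticated"})
    result = decode_token(token)
    if isinstance(result, TokenError):
        code = "token_expired" if result == TokenError.EXPIRED else "token_invalid"
        raise HTTPException(status_code=401, detail={"code": code, "message": "Invalid or expired token"})
    return await _load_user(result.sub)  # hypothetical: fetch user + token_version check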
# ── decode_token edge cases ──────────────────────────────────────────


def test_decode_token_empty_string_returns_malformed():
    _setup_config()
    result = decode_token("")
    assert result == TokenError.MALFORMED


def test_decode_token_whitespace_returns_malformed():
    _setup_config()
    result = decode_token(" ")
    assert result == TokenError.MALFORMED


# ── AuthConfig validation edge cases ─────────────────────────────────


def test_auth_config_missing_jwt_secret_raises():
    """AuthConfig requires jwt_secret — no default allowed."""
    with pytest.raises(ValidationError):
        AuthConfig()


def test_auth_config_token_expiry_zero_raises():
    """token_expiry_days must be >= 1."""
    with pytest.raises(ValidationError):
        AuthConfig(jwt_secret="secret", token_expiry_days=0)


def test_auth_config_token_expiry_31_raises():
    """token_expiry_days must be <= 30."""
    with pytest.raises(ValidationError):
        AuthConfig(jwt_secret="secret", token_expiry_days=31)


def test_auth_config_token_expiry_boundary_1_ok():
    config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
    assert config.token_expiry_days == 1


def test_auth_config_token_expiry_boundary_30_ok():
    config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
    assert config.token_expiry_days == 30


def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
    """get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
    import logging

    import app.gateway.auth.config as cfg

    old = cfg._auth_config
    cfg._auth_config = None
    try:
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("AUTH_JWT_SECRET", None)
            with caplog.at_level(logging.WARNING):
                config = cfg.get_auth_config()
            assert config.jwt_secret
            assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
    finally:
        cfg._auth_config = old


# ── CSRF middleware integration (unhappy paths) ──────────────────────


def _make_csrf_app():
    """Create a minimal FastAPI app with CSRFMiddleware for testing."""
    from fastapi import HTTPException as _HTTPException
    from fastapi.responses import JSONResponse as _JSONResponse

    app = FastAPI()

    @app.exception_handler(_HTTPException)
    async def _http_exc_handler(request, exc):
        return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})

    app.add_middleware(CSRFMiddleware)

    @app.post("/api/v1/test/protected")
    async def protected():
        return {"ok": True}

    @app.post("/api/v1/auth/login/local")
    async def login():
        return {"ok": True}

    @app.get("/api/v1/test/read")
    async def read_endpoint():
        return {"ok": True}

    return app


def test_csrf_middleware_blocks_post_without_token():
    """POST to protected endpoint without CSRF token → 403 with structured detail."""
    client = TestClient(_make_csrf_app())
    resp = client.post("/api/v1/test/protected")
    assert resp.status_code == 403
    assert "CSRF" in resp.json()["detail"]
    assert "missing" in resp.json()["detail"].lower()


def test_csrf_middleware_blocks_post_with_mismatched_token():
    """POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
    client = TestClient(_make_csrf_app())
    client.cookies.set(CSRF_COOKIE_NAME, "token-a")
    resp = client.post(
        "/api/v1/test/protected",
        headers={CSRF_HEADER_NAME: "token-b"},
    )
    assert resp.status_code == 403
    assert "mismatch" in resp.json()["detail"].lower()


def test_csrf_middleware_allows_post_with_matching_token():
    """POST with matching CSRF cookie/header → 200."""
    client = TestClient(_make_csrf_app())
    token = secrets.token_urlsafe(64)
    client.cookies.set(CSRF_COOKIE_NAME, token)
    resp = client.post(
        "/api/v1/test/protected",
        headers={CSRF_HEADER_NAME: token},
    )
    assert resp.status_code == 200


def test_csrf_middleware_allows_get_without_token():
    """GET requests bypass CSRF check."""
    client = TestClient(_make_csrf_app())
    resp = client.get("/api/v1/test/read")
    assert resp.status_code == 200


def test_csrf_middleware_exempts_login_local():
    """POST to login/local is exempt from CSRF (no token yet)."""
    client = TestClient(_make_csrf_app())
    resp = client.post("/api/v1/auth/login/local")
    assert resp.status_code == 200


def test_csrf_middleware_sets_cookie_on_auth_endpoint():
    """Auth endpoints should receive a CSRF cookie in response."""
    client = TestClient(_make_csrf_app())
    resp = client.post("/api/v1/auth/login/local")
    assert CSRF_COOKIE_NAME in resp.cookies
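Tying the two predicates together, a sketch of the middleware's dispatch. The header name and the return-a-response style are assumptions (the real middleware may raise HTTPException instead, which would explain why the test app registers its own handler); the cookie's Secure-flag handling is elided.

import secrets

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "x-csrf-token"  # assumed header name


class CSRFMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        if should_check_csrf(request):
            cookie = request.cookies.get(CSRF_COOKIE_NAME)
            header = request.headers.get(CSRF_HEADER_NAME)
            if not cookie or not header:
                return JSONResponse(status_code=403, content={"detail": "CSRF token missing"})
            if not secrets.compare_digest(cookie, header):
                return JSONResponse(status_code=403, content={"detail": "CSRF cookie/header mismatch"})
        response = await call_next(request)
        if is_auth_endpoint(request):
            # Issue a fresh readable (non-HttpOnly) token on login/register/logout.
            response.set_cookie(CSRF_COOKIE_NAME, secrets.token_urlsafe(32), httponly=False)
        return response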
# ── UserResponse edge cases ──────────────────────────────────────────


def test_user_response_missing_required_fields():
    """UserResponse with missing fields → ValidationError."""
    from app.gateway.auth.models import UserResponse

    with pytest.raises(ValidationError):
        UserResponse(id="1")  # missing email, system_role

    with pytest.raises(ValidationError):
        UserResponse(id="1", email="a@b.com")  # missing system_role


def test_user_response_empty_string_role_rejected():
    """Empty string is not a valid role."""
    from app.gateway.auth.models import UserResponse

    with pytest.raises(ValidationError):
        UserResponse(id="1", email="a@b.com", system_role="")


# ══════════════════════════════════════════════════════════════════════
# HTTP-LEVEL API CONTRACT TESTS
# ══════════════════════════════════════════════════════════════════════


def _make_auth_app():
    """Create FastAPI app with auth routes for contract testing."""
    from app.gateway.app import create_app

    return create_app()


def _get_auth_client():
    """Get TestClient for auth API contract tests."""
    return TestClient(_make_auth_app())


def test_api_auth_me_no_cookie_returns_structured_401():
    """/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
    _setup_config()
    client = _get_auth_client()
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "not_authenticated"
    assert "message" in body["detail"]


def test_api_auth_me_expired_token_returns_structured_401():
    """/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
    _setup_config()
    expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")

    client = _get_auth_client()
    client.cookies.set("access_token", token)
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "token_expired"


def test_api_auth_me_invalid_sig_returns_structured_401():
    """/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
    _setup_config()
    payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
    token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")

    client = _get_auth_client()
    client.cookies.set("access_token", token)
    resp = client.get("/api/v1/auth/me")
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "token_invalid"


def test_api_login_bad_credentials_returns_structured_401():
    """Login with wrong password → 401 with {code: 'invalid_credentials'}."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": "nonexistent@test.com", "password": "wrongpassword"},
    )
    assert resp.status_code == 401
    body = resp.json()
    assert body["detail"]["code"] == "invalid_credentials"


def test_api_login_success_no_token_in_body():
    """Successful login → response body has expires_in but NOT access_token."""
    _setup_config()
    client = _get_auth_client()
    # Register first
    client.post(
        "/api/v1/auth/register",
        json={"email": "contract-test@test.com", "password": "securepassword123"},
    )
    # Login
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": "contract-test@test.com", "password": "securepassword123"},
    )
    assert resp.status_code == 200
    body = resp.json()
    assert "expires_in" in body
    assert "access_token" not in body
    # Token should be in cookie, not body
    assert "access_token" in resp.cookies


def test_api_register_duplicate_returns_structured_400():
    """Register with duplicate email → 400 with {code: 'email_already_exists'}."""
    _setup_config()
    client = _get_auth_client()
    email = "dup-contract-test@test.com"
    # First register
    client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
    # Duplicate
    resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
    assert resp.status_code == 400
    body = resp.json()
    assert body["detail"]["code"] == "email_already_exists"


# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────


def _unique_email(prefix: str) -> str:
    return f"{prefix}-{secrets.token_hex(4)}@test.com"


def _get_set_cookie_headers(resp) -> list[str]:
    """Extract all set-cookie header values from a TestClient response."""
    return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]


def test_register_http_cookie_httponly_true_secure_false():
    """HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("http-cookie"), "password": "password123"},
    )
    assert resp.status_code == 201
    cookie_header = resp.headers.get("set-cookie", "")
    assert "access_token=" in cookie_header
    assert "httponly" in cookie_header.lower()
    assert "secure" not in cookie_header.lower().replace("samesite", "")


def test_register_https_cookie_httponly_true_secure_true():
    """HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("https-cookie"), "password": "password123"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 201
    cookie_header = resp.headers.get("set-cookie", "")
    assert "access_token=" in cookie_header
    assert "httponly" in cookie_header.lower()
    assert "secure" in cookie_header.lower()
    assert "max-age" in cookie_header.lower()


def test_login_https_sets_secure_cookie():
    """HTTPS login → access_token cookie has secure flag."""
    _setup_config()
    client = _get_auth_client()
    email = _unique_email("https-login")
    client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
    resp = client.post(
        "/api/v1/auth/login/local",
        data={"username": email, "password": "password123"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 200
    cookie_header = resp.headers.get("set-cookie", "")
    assert "access_token=" in cookie_header
    assert "httponly" in cookie_header.lower()
    assert "secure" in cookie_header.lower()


def test_csrf_cookie_secure_on_https():
    """HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("csrf-https"), "password": "password123"},
        headers={"x-forwarded-proto": "https"},
    )
    assert resp.status_code == 201
    csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
    assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
    csrf_header = csrf_cookies[0]
    assert "secure" in csrf_header.lower()
    assert "httponly" not in csrf_header.lower()


def test_csrf_cookie_not_secure_on_http():
    """HTTP register → csrf_token cookie does NOT have secure flag."""
    _setup_config()
    client = _get_auth_client()
    resp = client.post(
        "/api/v1/auth/register",
        json={"email": _unique_email("csrf-http"), "password": "password123"},
    )
    assert resp.status_code == 201
    csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
    assert csrf_cookies, "csrf_token cookie not set on HTTP register"
    csrf_header = csrf_cookies[0]
    assert "secure" not in csrf_header.lower().replace("samesite", "")
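Behind the HTTP-vs-HTTPS matrix sits a cookie helper that derives the Secure flag from the effective request scheme at runtime, per the commit message. A hypothetical sketch; the helper name, the SameSite value, and the session-cookie-on-HTTP choice are inferred from the assertions, not taken from the diff:

def set_auth_cookie(response, request, token: str, max_age: int) -> None:
    # Behind a TLS-terminating proxy the scheme arrives via x-forwarded-proto.
    is_https = request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
    response.set_cookie(
        "access_token",
        token,
        httponly=True,    # never readable from JS
        secure=is_https,  # Secure only when the request is effectively HTTPS
        samesite="lax",
        max_age=max_age if is_https else None,  # session cookie on plain HTTP
    )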
@@ -7,12 +7,12 @@ import json
 import tempfile
 from pathlib import Path
 from types import SimpleNamespace
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock, MagicMock, patch

 import pytest

 from app.channels.base import Channel
-from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
+from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
 from app.channels.store import ChannelStore
@@ -1718,6 +1718,159 @@ class TestFeishuChannel:
         _run(go())


+class TestWeComChannel:
+    def test_publish_ws_inbound_starts_stream_and_publishes_message(self, monkeypatch):
+        from app.channels.wecom import WeComChannel
+
+        async def go():
+            bus = MessageBus()
+            bus.publish_inbound = AsyncMock()
+            channel = WeComChannel(bus, config={})
+            channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
+
+            monkeypatch.setitem(
+                __import__("sys").modules,
+                "aibot",
+                SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
+            )
+
+            frame = {
+                "body": {
+                    "msgid": "msg-1",
+                    "from": {"userid": "user-1"},
+                    "aibotid": "bot-1",
+                    "chattype": "single",
+                }
+            }
+            files = [{"type": "image", "url": "https://example.com/image.png"}]
+
+            await channel._publish_ws_inbound(frame, "hello", files=files)
+
+            channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Working on it...", False)
+            bus.publish_inbound.assert_awaited_once()
+
+            inbound = bus.publish_inbound.await_args.args[0]
+            assert inbound.channel_name == "wecom"
+            assert inbound.chat_id == "user-1"
+            assert inbound.user_id == "user-1"
+            assert inbound.text == "hello"
+            assert inbound.thread_ts == "msg-1"
+            assert inbound.topic_id == "user-1"
+            assert inbound.files == files
+            assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"}
+            assert channel._ws_frames["msg-1"] is frame
+            assert channel._ws_stream_ids["msg-1"] == "stream-1"
+
+        _run(go())
+
+    def test_publish_ws_inbound_uses_configured_working_message(self, monkeypatch):
+        from app.channels.wecom import WeComChannel
+
+        async def go():
+            bus = MessageBus()
+            bus.publish_inbound = AsyncMock()
+            channel = WeComChannel(bus, config={"working_message": "Please wait..."})
+            channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
+            channel._working_message = "Please wait..."
+
+            monkeypatch.setitem(
+                __import__("sys").modules,
+                "aibot",
+                SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
+            )
+
+            frame = {
+                "body": {
+                    "msgid": "msg-1",
+                    "from": {"userid": "user-1"},
+                }
+            }
+
+            await channel._publish_ws_inbound(frame, "hello")
+
+            channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Please wait...", False)
+
+        _run(go())
+
+    def test_on_outbound_sends_attachment_before_clearing_context(self, tmp_path):
+        from app.channels.wecom import WeComChannel
+
+        async def go():
+            bus = MessageBus()
+            channel = WeComChannel(bus, config={})
+
+            frame = {"body": {"msgid": "msg-1"}}
+            ws_client = SimpleNamespace(
+                reply_stream=AsyncMock(),
+                reply=AsyncMock(),
+            )
+            channel._ws_client = ws_client
+            channel._ws_frames["msg-1"] = frame
+            channel._ws_stream_ids["msg-1"] = "stream-1"
+            channel._upload_media_ws = AsyncMock(return_value="media-1")
+
+            attachment_path = tmp_path / "image.png"
+            attachment_path.write_bytes(b"png")
+            attachment = ResolvedAttachment(
+                virtual_path="/mnt/user-data/outputs/image.png",
+                actual_path=attachment_path,
+                filename="image.png",
+                mime_type="image/png",
+                size=attachment_path.stat().st_size,
+                is_image=True,
+            )
+
+            msg = OutboundMessage(
+                channel_name="wecom",
+                chat_id="user-1",
+                thread_id="thread-1",
+                text="done",
+                attachments=[attachment],
+                is_final=True,
+                thread_ts="msg-1",
+            )
+
+            await channel._on_outbound(msg)
+
+            ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "done", True)
+            channel._upload_media_ws.assert_awaited_once_with(
+                media_type="image",
+                filename="image.png",
+                path=str(attachment_path),
+                size=attachment.size,
+            )
+            ws_client.reply.assert_awaited_once_with(frame, {"image": {"media_id": "media-1"}, "msgtype": "image"})
+            assert "msg-1" not in channel._ws_frames
+            assert "msg-1" not in channel._ws_stream_ids
+
+        _run(go())
+
+    def test_send_falls_back_to_send_message_without_thread_context(self):
+        from app.channels.wecom import WeComChannel
+
+        async def go():
+            bus = MessageBus()
+            channel = WeComChannel(bus, config={})
+            channel._ws_client = SimpleNamespace(send_message=AsyncMock())
+
+            msg = OutboundMessage(
+                channel_name="wecom",
+                chat_id="user-1",
+                thread_id="thread-1",
+                text="hello",
+                thread_ts=None,
+            )
+
+            await channel.send(msg)
+
+            channel._ws_client.send_message.assert_awaited_once_with(
+                "user-1",
+                {"msgtype": "markdown", "markdown": {"content": "hello"}},
+            )
+
+        _run(go())
+
+
 class TestChannelService:
     def test_get_status_no_channels(self):
         from app.channels.service import ChannelService
@@ -1835,6 +1988,47 @@ class TestSlackSendRetry:
         _run(go())


+class TestSlackAllowedUsers:
+    def test_numeric_allowed_users_match_string_event_user_id(self):
+        from app.channels.slack import SlackChannel
+
+        bus = MessageBus()
+        bus.publish_inbound = AsyncMock()
+        channel = SlackChannel(
+            bus=bus,
+            config={"allowed_users": [123456]},
+        )
+        channel._loop = MagicMock()
+        channel._loop.is_running.return_value = True
+        channel._add_reaction = MagicMock()
+        channel._send_running_reply = MagicMock()
+
+        event = {
+            "user": "123456",
+            "text": "hello from slack",
+            "channel": "C123",
+            "ts": "1710000000.000100",
+        }
+
+        def submit_coro(coro, loop):
+            coro.close()
+            return MagicMock()
+
+        with patch(
+            "app.channels.slack.asyncio.run_coroutine_threadsafe",
+            side_effect=submit_coro,
+        ) as submit:
+            channel._handle_message_event(event)
+
+        channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
+        channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
+        submit.assert_called_once()
+        inbound = bus.publish_inbound.call_args.args[0]
+        assert inbound.user_id == "123456"
+        assert inbound.chat_id == "C123"
+        assert inbound.text == "hello from slack"
+
     def test_raises_after_all_retries_exhausted(self):
         from app.channels.slack import SlackChannel
@@ -1854,6 +2048,20 @@ class TestSlackSendRetry:

         _run(go())

+    def test_raises_runtime_error_when_no_attempts_configured(self):
+        from app.channels.slack import SlackChannel
+
+        async def go():
+            bus = MessageBus()
+            ch = SlackChannel(bus=bus, config={"bot_token": "xoxb-test", "app_token": "xapp-test"})
+            ch._web_client = MagicMock()
+
+            msg = OutboundMessage(channel_name="slack", chat_id="C123", thread_id="t1", text="hello")
+            with pytest.raises(RuntimeError, match="without an exception"):
+                await ch.send(msg, _max_retries=0)
+
+        _run(go())
+

 # ---------------------------------------------------------------------------
 # Telegram send retry tests
@@ -1912,6 +2120,36 @@ class TestTelegramSendRetry:

         _run(go())

+    def test_raises_runtime_error_when_no_attempts_configured(self):
+        from app.channels.telegram import TelegramChannel
+
+        async def go():
+            bus = MessageBus()
+            ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"})
+            ch._application = MagicMock()
+
+            msg = OutboundMessage(channel_name="telegram", chat_id="12345", thread_id="t1", text="hello")
+            with pytest.raises(RuntimeError, match="without an exception"):
+                await ch.send(msg, _max_retries=0)
+
+        _run(go())
+
+
+class TestFeishuSendRetry:
+    def test_raises_runtime_error_when_no_attempts_configured(self):
+        from app.channels.feishu import FeishuChannel
+
+        async def go():
+            bus = MessageBus()
+            ch = FeishuChannel(bus=bus, config={"app_id": "id", "app_secret": "secret"})
+            ch._api_client = MagicMock()
+
+            msg = OutboundMessage(channel_name="feishu", chat_id="chat", thread_id="t1", text="hello")
+            with pytest.raises(RuntimeError, match="without an exception"):
+                await ch.send(msg, _max_retries=0)
+
+        _run(go())
+

 # ---------------------------------------------------------------------------
 # Telegram private-chat thread context tests
@@ -59,18 +59,20 @@ class TestClientInit:
         assert client._subagent_enabled is False
         assert client._plan_mode is False
         assert client._agent_name is None
+        assert client._available_skills is None
         assert client._checkpointer is None
         assert client._agent is None

     def test_custom_params(self, mock_app_config):
         mock_middleware = MagicMock()
         with patch("deerflow.client.get_app_config", return_value=mock_app_config):
-            c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", middlewares=[mock_middleware])
+            c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
         assert c._model_name == "gpt-4"
         assert c._thinking_enabled is False
         assert c._subagent_enabled is True
         assert c._plan_mode is True
         assert c._agent_name == "test-agent"
+        assert c._available_skills == {"skill1", "skill2"}
         assert c._middlewares == [mock_middleware]

     def test_invalid_agent_name(self, mock_app_config):
@@ -394,8 +396,10 @@
             patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
             patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
             patch.object(client, "_get_tools", return_value=[]),
+            patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
         ):
             client._agent_name = "custom-agent"
+            client._available_skills = {"test_skill"}
             client._ensure_agent(config)

         assert client._agent is mock_agent
@@ -404,6 +408,7 @@
         assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
         mock_apply_prompt.assert_called_once()
         assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
+        assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}

     def test_uses_default_checkpointer_when_available(self, client):
         mock_agent = MagicMock()
@@ -441,6 +446,7 @@
             patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
             patch("deerflow.client.apply_prompt_template", return_value="prompt"),
             patch.object(client, "_get_tools", return_value=[]),
+            patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
         ):
             client._ensure_agent(config)

@@ -469,7 +475,7 @@
         """_ensure_agent does not recreate if config key unchanged."""
         mock_agent = MagicMock()
         client._agent = mock_agent
-        client._agent_config_key = (None, True, False, False)
+        client._agent_config_key = (None, True, False, False, None, None)

         config = client._get_runnable_config("t1")
         client._ensure_agent(config)
@@ -1276,6 +1282,7 @@ class TestScenarioAgentRecreation:
             patch("deerflow.client._build_middlewares", return_value=[]),
             patch("deerflow.client.apply_prompt_template", return_value="prompt"),
             patch.object(client, "_get_tools", return_value=[]),
+            patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
         ):
             client._ensure_agent(config_a)
             first_agent = client._agent
@@ -1303,6 +1310,7 @@
             patch("deerflow.client._build_middlewares", return_value=[]),
             patch("deerflow.client.apply_prompt_template", return_value="prompt"),
             patch.object(client, "_get_tools", return_value=[]),
+            patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
         ):
             client._ensure_agent(config)
             client._ensure_agent(config)
@@ -1327,6 +1335,7 @@
             patch("deerflow.client._build_middlewares", return_value=[]),
             patch("deerflow.client.apply_prompt_template", return_value="prompt"),
             patch.object(client, "_get_tools", return_value=[]),
+            patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
         ):
             client._ensure_agent(config)
             client.reset_agent()
@@ -439,6 +439,15 @@ class TestAgentsAPI:
         assert "agent-one" in names
         assert "agent-two" in names

+    def test_list_agents_includes_soul(self, agent_client):
+        agent_client.post("/api/agents", json={"name": "soul-agent", "soul": "My soul content"})
+
+        response = agent_client.get("/api/agents")
+        assert response.status_code == 200
+        agents = response.json()["agents"]
+        soul_agent = next(a for a in agents if a["name"] == "soul-agent")
+        assert soul_agent["soul"] == "My soul content"
+
     def test_get_agent(self, agent_client):
         agent_client.post("/api/agents", json={"name": "test-agent", "soul": "Hello world"})
backend/tests/test_ensure_admin.py (new file, 214 lines)
@@ -0,0 +1,214 @@
"""Tests for _ensure_admin_user() in app.py.

Covers: first-boot admin creation, auto-reset on needs_setup=True,
no-op on needs_setup=False, migration, and edge cases.
"""

import asyncio
import os
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch

import pytest

os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.models import User

_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"


@pytest.fixture(autouse=True)
def _setup_auth_config():
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
    yield
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))


def _make_app_stub(store=None):
    """Minimal app-like object with state.store."""
    app = SimpleNamespace()
    app.state = SimpleNamespace()
    app.state.store = store
    return app


def _make_provider(user_count=0, admin_user=None):
    p = AsyncMock()
    p.count_users = AsyncMock(return_value=user_count)
    p.create_user = AsyncMock(
        side_effect=lambda **kw: User(
            email=kw["email"],
            password_hash="hashed",
            system_role=kw.get("system_role", "user"),
            needs_setup=kw.get("needs_setup", False),
        )
    )
    p.get_user_by_email = AsyncMock(return_value=admin_user)
    p.update_user = AsyncMock(side_effect=lambda u: u)
    return p


# ── First boot: no users ─────────────────────────────────────────────────


def test_first_boot_creates_admin():
    """count_users==0 → create admin with needs_setup=True."""
    provider = _make_provider(user_count=0)
    app = _make_app_stub()

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    provider.create_user.assert_called_once()
    call_kwargs = provider.create_user.call_args[1]
    assert call_kwargs["email"] == "admin@deerflow.dev"
    assert call_kwargs["system_role"] == "admin"
    assert call_kwargs["needs_setup"] is True
    assert len(call_kwargs["password"]) > 10  # random password generated


def test_first_boot_triggers_migration_if_store_present():
    """First boot with store → _migrate_orphaned_threads called."""
    provider = _make_provider(user_count=0)
    store = AsyncMock()
    store.asearch = AsyncMock(return_value=[])
    app = _make_app_stub(store=store)

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    store.asearch.assert_called_once()


def test_first_boot_no_store_skips_migration():
    """First boot without store → no crash, migration skipped."""
    provider = _make_provider(user_count=0)
    app = _make_app_stub(store=None)

    with patch("app.gateway.deps.get_local_provider", return_value=provider):
        with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
            from app.gateway.app import _ensure_admin_user

            asyncio.run(_ensure_admin_user(app))

    provider.create_user.assert_called_once()
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_true_resets_password():
|
||||
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="old-hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=0,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
# Password was reset
|
||||
provider.update_user.assert_called_once()
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.password_hash == "new-hash"
|
||||
assert updated.token_version == 1
|
||||
|
||||
|
||||
def test_needs_setup_true_consecutive_resets_increment_version():
|
||||
"""Two boots with needs_setup=True → token_version increments each time."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.token_version == 4
|
||||
|
||||
|
||||
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_false_no_reset():
|
||||
"""Admin with needs_setup=False → no password reset, no update."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="stable-hash",
|
||||
system_role="admin",
|
||||
needs_setup=False,
|
||||
token_version=2,
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
assert admin.password_hash == "stable-hash"
|
||||
assert admin.token_version == 2
|
||||
|
||||
|
||||
# ── Edge cases ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_admin_email_found_no_crash():
|
||||
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
|
||||
provider = _make_provider(user_count=3, admin_user=None)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
provider.create_user.assert_not_called()
|
||||
|
||||
|
||||
def test_migration_failure_is_non_fatal():
|
||||
"""_migrate_orphaned_threads exception is caught and logged."""
|
||||
provider = _make_provider(user_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
# Should not raise
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
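
Reviewer note: `_ensure_admin_user` itself is not part of this diff. As a reading aid, here is a minimal sketch of the contract the assertions above pin down — the name `ensure_admin_sketch`, its injected arguments, and the migration stand-in are illustrative, not the shipped code:

```python
# Sketch only: behaviour reconstructed from the tests, not the real app.gateway.app code.
import secrets

ADMIN_EMAIL = "admin@deerflow.dev"


async def ensure_admin_sketch(app, provider, hash_password_async):
    if await provider.count_users() == 0:
        # First boot: create the admin with a random throwaway password.
        await provider.create_user(
            email=ADMIN_EMAIL,
            password=secrets.token_urlsafe(16),
            system_role="admin",
            needs_setup=True,
        )
        store = getattr(app.state, "store", None)
        if store is not None:
            try:
                await store.asearch(("threads",))  # stand-in for orphan-thread migration
            except Exception:
                pass  # migration failures are non-fatal (logged in the real code)
        return
    admin = await provider.get_user_by_email(ADMIN_EMAIL)
    if admin is not None and admin.needs_setup:
        # Auto-reset: new hash plus a token_version bump revokes stale sessions.
        admin.password_hash = await hash_password_async(secrets.token_urlsafe(16))
        admin.token_version += 1
        await provider.update_user(admin)
```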
459
backend/tests/test_file_conversion.py
Normal file
@ -0,0 +1,459 @@
"""Tests for file_conversion utilities (PR1: pymupdf4llm + asyncio.to_thread; PR2: extract_outline)."""

from __future__ import annotations

import asyncio
import sys
from types import ModuleType
from unittest.mock import MagicMock, patch

from deerflow.utils.file_conversion import (
    _ASYNC_THRESHOLD_BYTES,
    _MIN_CHARS_PER_PAGE,
    MAX_OUTLINE_ENTRIES,
    _do_convert,
    _pymupdf_output_too_sparse,
    convert_file_to_markdown,
    extract_outline,
)


def _make_pymupdf_mock(page_count: int) -> ModuleType:
    """Return a fake *pymupdf* module whose ``open()`` reports *page_count* pages."""
    mock_doc = MagicMock()
    mock_doc.__len__ = MagicMock(return_value=page_count)
    fake_pymupdf = ModuleType("pymupdf")
    fake_pymupdf.open = MagicMock(return_value=mock_doc)  # type: ignore[attr-defined]
    return fake_pymupdf


def _run(coro):
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(coro)
    finally:
        loop.close()


# ---------------------------------------------------------------------------
# _pymupdf_output_too_sparse
# ---------------------------------------------------------------------------


class TestPymupdfOutputTooSparse:
    """Check the chars-per-page sparsity heuristic."""

    def test_dense_text_pdf_not_sparse(self, tmp_path):
        """Normal text PDF: many chars per page → not sparse."""
        pdf = tmp_path / "dense.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        # 10 pages × 10 000 chars → 1000/page ≫ threshold
        with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=10)}):
            result = _pymupdf_output_too_sparse("x" * 10_000, pdf)
        assert result is False

    def test_image_based_pdf_is_sparse(self, tmp_path):
        """Image-based PDF: near-zero chars per page → sparse."""
        pdf = tmp_path / "image.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        # 612 chars / 31 pages ≈ 19.7/page < _MIN_CHARS_PER_PAGE (50)
        with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=31)}):
            result = _pymupdf_output_too_sparse("x" * 612, pdf)
        assert result is True

    def test_fallback_when_pymupdf_unavailable(self, tmp_path):
        """When pymupdf is not installed, fall back to absolute 200-char threshold."""
        pdf = tmp_path / "broken.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        # Remove pymupdf from sys.modules so the `import pymupdf` inside the
        # function raises ImportError, triggering the absolute-threshold fallback.
        with patch.dict(sys.modules, {"pymupdf": None}):
            sparse = _pymupdf_output_too_sparse("x" * 100, pdf)
            not_sparse = _pymupdf_output_too_sparse("x" * 300, pdf)

        assert sparse is True
        assert not_sparse is False

    def test_exactly_at_threshold_is_not_sparse(self, tmp_path):
        """Chars-per-page == threshold is treated as NOT sparse (boundary inclusive)."""
        pdf = tmp_path / "boundary.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        # 2 pages × _MIN_CHARS_PER_PAGE chars = exactly at threshold
        with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=2)}):
            result = _pymupdf_output_too_sparse("x" * (_MIN_CHARS_PER_PAGE * 2), pdf)
        assert result is False

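
The heuristic under test, in miniature. The thresholds are the imported constants (50 chars/page, 200-char absolute fallback, per the test comments above); the function body is an assumed reconstruction, not the shipped one:

```python
# Sketch of the sparsity check, assuming the thresholds named in the tests.
from pathlib import Path

MIN_CHARS_PER_PAGE = 50   # mirrors _MIN_CHARS_PER_PAGE
ABS_FALLBACK_CHARS = 200  # absolute threshold when pymupdf is unavailable


def output_too_sparse(markdown: str, pdf_path: Path) -> bool:
    try:
        import pymupdf

        page_count = len(pymupdf.open(pdf_path))
    except ImportError:
        # No pymupdf → can't count pages; fall back to an absolute length check.
        return len(markdown) < ABS_FALLBACK_CHARS
    if page_count == 0:
        return False
    # Strict less-than: exactly-at-threshold counts as dense (boundary inclusive).
    return len(markdown) / page_count < MIN_CHARS_PER_PAGE
```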
# ---------------------------------------------------------------------------
# _do_convert — routing logic
# ---------------------------------------------------------------------------


class TestDoConvert:
    """Verify that _do_convert routes to the right sub-converter."""

    def test_non_pdf_always_uses_markitdown(self, tmp_path):
        """DOCX / XLSX / PPTX always go through MarkItDown regardless of setting."""
        docx = tmp_path / "report.docx"
        docx.write_bytes(b"PK fake docx")

        with patch(
            "deerflow.utils.file_conversion._convert_with_markitdown",
            return_value="# Markdown from MarkItDown",
        ) as mock_md:
            result = _do_convert(docx, "auto")

        mock_md.assert_called_once_with(docx)
        assert result == "# Markdown from MarkItDown"

    def test_pdf_auto_uses_pymupdf4llm_when_dense(self, tmp_path):
        """auto mode: use pymupdf4llm output when it's dense enough."""
        pdf = tmp_path / "report.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        dense_text = "# Heading\n" + "word " * 2000  # clearly dense

        with (
            patch(
                "deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
                return_value=dense_text,
            ),
            patch(
                "deerflow.utils.file_conversion._pymupdf_output_too_sparse",
                return_value=False,
            ),
            patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
        ):
            result = _do_convert(pdf, "auto")

        mock_md.assert_not_called()
        assert result == dense_text

    def test_pdf_auto_falls_back_when_sparse(self, tmp_path):
        """auto mode: fall back to MarkItDown when pymupdf4llm output is sparse."""
        pdf = tmp_path / "scanned.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        with (
            patch(
                "deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
                return_value="x" * 612,  # 19.7 chars/page for 31-page doc
            ),
            patch(
                "deerflow.utils.file_conversion._pymupdf_output_too_sparse",
                return_value=True,
            ),
            patch(
                "deerflow.utils.file_conversion._convert_with_markitdown",
                return_value="OCR result via MarkItDown",
            ) as mock_md,
        ):
            result = _do_convert(pdf, "auto")

        mock_md.assert_called_once_with(pdf)
        assert result == "OCR result via MarkItDown"

    def test_pdf_explicit_pymupdf4llm_skips_sparsity_check(self, tmp_path):
        """'pymupdf4llm' mode: use output as-is even if sparse."""
        pdf = tmp_path / "explicit.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        sparse_text = "x" * 10  # very short

        with (
            patch(
                "deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
                return_value=sparse_text,
            ),
            patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
        ):
            result = _do_convert(pdf, "pymupdf4llm")

        mock_md.assert_not_called()
        assert result == sparse_text

    def test_pdf_explicit_markitdown_skips_pymupdf4llm(self, tmp_path):
        """'markitdown' mode: never attempt pymupdf4llm."""
        pdf = tmp_path / "force_md.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        with (
            patch("deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm") as mock_pymu,
            patch(
                "deerflow.utils.file_conversion._convert_with_markitdown",
                return_value="MarkItDown result",
            ),
        ):
            result = _do_convert(pdf, "markitdown")

        mock_pymu.assert_not_called()
        assert result == "MarkItDown result"

    def test_pdf_auto_falls_back_when_pymupdf4llm_not_installed(self, tmp_path):
        """auto mode: if pymupdf4llm is not installed, use MarkItDown directly."""
        pdf = tmp_path / "no_pymupdf.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        with (
            patch(
                "deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
                return_value=None,  # None signals not installed
            ),
            patch(
                "deerflow.utils.file_conversion._convert_with_markitdown",
                return_value="MarkItDown fallback",
            ) as mock_md,
        ):
            result = _do_convert(pdf, "auto")

        mock_md.assert_called_once_with(pdf)
        assert result == "MarkItDown fallback"

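
The routing those six cases imply, condensed into one sketch. The converter callables are injected here so the snippet stands alone; the real module calls its private helpers directly:

```python
# Sketch of _do_convert's routing; the injected callable names are illustrative.
from pathlib import Path
from typing import Callable, Optional


def do_convert_sketch(
    path: Path,
    mode: str,  # "auto" | "pymupdf4llm" | "markitdown"
    pymupdf4llm_convert: Callable[[Path], Optional[str]],
    markitdown_convert: Callable[[Path], str],
    too_sparse: Callable[[str, Path], bool],
) -> Optional[str]:
    # Non-PDF inputs (DOCX/XLSX/PPTX) always go through MarkItDown.
    if path.suffix.lower() != ".pdf" or mode == "markitdown":
        return markitdown_convert(path)
    text = pymupdf4llm_convert(path)  # None signals pymupdf4llm not installed
    if mode == "pymupdf4llm":
        return text  # explicit mode: no sparsity check, use output as-is
    if text is None or too_sparse(text, path):
        return markitdown_convert(path)  # auto mode falls back (e.g. scanned PDFs)
    return text
```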
# ---------------------------------------------------------------------------
# convert_file_to_markdown — async + file writing
# ---------------------------------------------------------------------------


class TestConvertFileToMarkdown:
    def test_small_file_runs_synchronously(self, tmp_path):
        """Small files (< 1 MB) are converted in the event loop thread."""
        pdf = tmp_path / "small.pdf"
        pdf.write_bytes(b"%PDF-1.4 " + b"x" * 100)  # well under 1 MB

        with (
            patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
            patch(
                "deerflow.utils.file_conversion._do_convert",
                return_value="# Small PDF",
            ) as mock_convert,
            patch("asyncio.to_thread") as mock_thread,
        ):
            md_path = _run(convert_file_to_markdown(pdf))

        # asyncio.to_thread must NOT have been called
        mock_thread.assert_not_called()
        mock_convert.assert_called_once()
        assert md_path == pdf.with_suffix(".md")
        assert md_path.read_text() == "# Small PDF"

    def test_large_file_offloaded_to_thread(self, tmp_path):
        """Large files (> 1 MB) are offloaded via asyncio.to_thread."""
        pdf = tmp_path / "large.pdf"
        # Write slightly more than the threshold
        pdf.write_bytes(b"%PDF-1.4 " + b"x" * (_ASYNC_THRESHOLD_BYTES + 1))

        async def fake_to_thread(fn, *args, **kwargs):
            return fn(*args, **kwargs)

        with (
            patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
            patch(
                "deerflow.utils.file_conversion._do_convert",
                return_value="# Large PDF",
            ),
            patch("asyncio.to_thread", side_effect=fake_to_thread) as mock_thread,
        ):
            md_path = _run(convert_file_to_markdown(pdf))

        mock_thread.assert_called_once()
        assert md_path == pdf.with_suffix(".md")
        assert md_path.read_text() == "# Large PDF"

    def test_returns_none_on_conversion_error(self, tmp_path):
        """If conversion raises, return None without propagating the exception."""
        pdf = tmp_path / "broken.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")

        with (
            patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
            patch(
                "deerflow.utils.file_conversion._do_convert",
                side_effect=RuntimeError("conversion failed"),
            ),
        ):
            result = _run(convert_file_to_markdown(pdf))

        assert result is None

    def test_writes_utf8_markdown_file(self, tmp_path):
        """Generated .md file is written with UTF-8 encoding."""
        pdf = tmp_path / "report.pdf"
        pdf.write_bytes(b"%PDF-1.4 fake")
        chinese_content = "# 中文报告\n\n这是测试内容。"

        with (
            patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
            patch(
                "deerflow.utils.file_conversion._do_convert",
                return_value=chinese_content,
            ),
        ):
            md_path = _run(convert_file_to_markdown(pdf))

        assert md_path is not None
        assert md_path.read_text(encoding="utf-8") == chinese_content

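
The size-gated offload pattern the first two tests exercise, reduced to its core. The 1 MB figure comes from the test docstrings; the wrapper name and error-handling shape are assumptions:

```python
# Sketch: convert inline when small, hop to a worker thread when large.
import asyncio
from pathlib import Path
from typing import Callable, Optional

ASYNC_THRESHOLD_BYTES = 1024 * 1024  # assumed 1 MB, per the docstrings above


async def convert_sketch(path: Path, convert: Callable[[Path], str]) -> Optional[Path]:
    try:
        if path.stat().st_size > ASYNC_THRESHOLD_BYTES:
            markdown = await asyncio.to_thread(convert, path)  # don't block the loop
        else:
            markdown = convert(path)  # cheap enough to run inline
        out = path.with_suffix(".md")
        out.write_text(markdown, encoding="utf-8")  # UTF-8, per the test above
        return out
    except Exception:
        return None  # conversion errors are swallowed, per the tests
```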
# ---------------------------------------------------------------------------
# extract_outline
# ---------------------------------------------------------------------------


class TestExtractOutline:
    """Tests for extract_outline()."""

    def test_empty_file_returns_empty(self, tmp_path):
        """Empty markdown file yields no outline entries."""
        md = tmp_path / "empty.md"
        md.write_text("", encoding="utf-8")
        assert extract_outline(md) == []

    def test_missing_file_returns_empty(self, tmp_path):
        """Non-existent path returns [] without raising."""
        assert extract_outline(tmp_path / "nonexistent.md") == []

    def test_standard_markdown_headings(self, tmp_path):
        """# / ## / ### headings are all recognised."""
        md = tmp_path / "doc.md"
        md.write_text(
            "# Chapter One\n\nSome text.\n\n## Section 1.1\n\nMore text.\n\n### Sub 1.1.1\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        assert len(outline) == 3
        assert outline[0] == {"title": "Chapter One", "line": 1}
        assert outline[1] == {"title": "Section 1.1", "line": 5}
        assert outline[2] == {"title": "Sub 1.1.1", "line": 9}

    def test_bold_sec_item_heading(self, tmp_path):
        """**ITEM N. TITLE** lines in SEC filings are recognised."""
        md = tmp_path / "10k.md"
        md.write_text(
            "Cover page text.\n\n**ITEM 1. BUSINESS**\n\nBody.\n\n**ITEM 1A. RISK FACTORS**\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        assert len(outline) == 2
        assert outline[0] == {"title": "ITEM 1. BUSINESS", "line": 3}
        assert outline[1] == {"title": "ITEM 1A. RISK FACTORS", "line": 7}

    def test_bold_part_heading(self, tmp_path):
        """**PART I** / **PART II** headings are recognised."""
        md = tmp_path / "10k.md"
        md.write_text("**PART I**\n\n**PART II**\n\n**PART III**\n", encoding="utf-8")
        outline = extract_outline(md)
        assert len(outline) == 3
        titles = [e["title"] for e in outline]
        assert "PART I" in titles
        assert "PART II" in titles
        assert "PART III" in titles

    def test_sec_cover_page_boilerplate_excluded(self, tmp_path):
        """Address lines and short cover boilerplate must NOT appear in outline."""
        md = tmp_path / "8k.md"
        md.write_text(
            "## **UNITED STATES SECURITIES AND EXCHANGE COMMISSION**\n\n**WASHINGTON, DC 20549**\n\n**CURRENT REPORT**\n\n**SIGNATURES**\n\n**TESLA, INC.**\n\n**ITEM 2.02. RESULTS OF OPERATIONS**\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        titles = [e["title"] for e in outline]
        # Cover-page boilerplate should be excluded
        assert "WASHINGTON, DC 20549" not in titles
        assert "CURRENT REPORT" not in titles
        assert "SIGNATURES" not in titles
        assert "TESLA, INC." not in titles
        # Real SEC heading must be included
        assert "ITEM 2.02. RESULTS OF OPERATIONS" in titles

    def test_chinese_headings_via_standard_markdown(self, tmp_path):
        """Chinese annual report headings emitted as # by pymupdf4llm are captured."""
        md = tmp_path / "annual.md"
        md.write_text(
            "# 第一节 公司简介\n\n内容。\n\n## 第三节 管理层讨论与分析\n\n分析内容。\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        assert len(outline) == 2
        assert outline[0]["title"] == "第一节 公司简介"
        assert outline[1]["title"] == "第三节 管理层讨论与分析"

    def test_outline_capped_at_max_entries(self, tmp_path):
        """When truncated, result has MAX_OUTLINE_ENTRIES real entries + 1 sentinel."""
        lines = [f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 10)]
        md = tmp_path / "long.md"
        md.write_text("\n".join(lines), encoding="utf-8")
        outline = extract_outline(md)
        # Last entry is the truncation sentinel
        assert outline[-1] == {"truncated": True}
        # Visible entries are exactly MAX_OUTLINE_ENTRIES
        visible = [e for e in outline if not e.get("truncated")]
        assert len(visible) == MAX_OUTLINE_ENTRIES

    def test_no_truncation_sentinel_when_under_limit(self, tmp_path):
        """Short documents produce no sentinel entry."""
        lines = [f"# Heading {i}" for i in range(5)]
        md = tmp_path / "short.md"
        md.write_text("\n".join(lines), encoding="utf-8")
        outline = extract_outline(md)
        assert len(outline) == 5
        assert not any(e.get("truncated") for e in outline)

    def test_blank_lines_and_whitespace_ignored(self, tmp_path):
        """Blank lines between headings do not produce empty entries."""
        md = tmp_path / "spaced.md"
        md.write_text("\n\n# Title One\n\n\n\n# Title Two\n\n", encoding="utf-8")
        outline = extract_outline(md)
        assert len(outline) == 2
        assert all(e["title"] for e in outline)

    def test_inline_bold_not_confused_with_heading(self, tmp_path):
        """Mid-sentence bold text must not be mistaken for a heading."""
        md = tmp_path / "prose.md"
        md.write_text(
            "This sentence has **bold words** inside it.\n\nAnother with **MULTIPLE CAPS** inline.\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        assert outline == []

    def test_split_bold_heading_academic_paper(self, tmp_path):
        """**<num>** **<title>** lines from academic papers are recognised (Style 3)."""
        md = tmp_path / "paper.md"
        md.write_text(
            "## **Attention Is All You Need**\n\n**1** **Introduction**\n\nBody text.\n\n**2** **Background**\n\nMore text.\n\n**3.1** **Encoder and Decoder Stacks**\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        titles = [e["title"] for e in outline]
        assert "1 Introduction" in titles
        assert "2 Background" in titles
        assert "3.1 Encoder and Decoder Stacks" in titles

    def test_split_bold_year_columns_excluded(self, tmp_path):
        """Financial table headers like **2023** **2022** **2021** are NOT headings."""
        md = tmp_path / "annual.md"
        md.write_text(
            "# Financial Summary\n\n**2023** **2022** **2021**\n\nRevenue 100 90 80\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        titles = [e["title"] for e in outline]
        # Only the # heading should appear, not the year-column row
        assert titles == ["Financial Summary"]

    def test_adjacent_bold_spans_merged_in_markdown_heading(self, tmp_path):
        """** ** artefacts inside a # heading are merged into clean plain text."""
        md = tmp_path / "sec.md"
        md.write_text(
            "## **UNITED STATES** **SECURITIES AND EXCHANGE COMMISSION**\n\nBody text.\n",
            encoding="utf-8",
        )
        outline = extract_outline(md)
        assert len(outline) == 1
        # Title must be clean — no ** ** artefacts
        assert outline[0]["title"] == "UNITED STATES SECURITIES AND EXCHANGE COMMISSION"
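
For orientation, the heading shapes these outline tests cover — standard `#` headings, bold SEC `ITEM`/`PART` lines, and split-bold academic numbering — could be matched with patterns along these lines (illustrative approximations, not the shipped regexes):

```python
# Hypothetical patterns for the heading styles exercised above.
import re

MD_HEADING = re.compile(r"^#{1,6}\s+(.+)$")                # "# Chapter One"
SEC_ITEM = re.compile(r"^\*\*(ITEM\s+\d+[A-Z]?\..+)\*\*$")  # "**ITEM 1A. RISK FACTORS**"
SEC_PART = re.compile(r"^\*\*(PART\s+[IVX]+)\*\*$")         # "**PART I**"
SPLIT_BOLD = re.compile(r"^\*\*\d+(?:\.\d+)*\*\*(?:\s+\*\*\D.+\*\*)+$")  # "**3.1** **Encoder...**"


def clean_title(raw: str) -> str:
    """Merge adjacent bold spans: '**A** **B**' -> 'A B'."""
    return raw.replace("**", "").strip()
```

The `\D` in `SPLIT_BOLD` is one way to keep year-column rows like `**2023** **2022**` out of the outline; the real filter (plus the cover-page boilerplate exclusions) is more involved.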
@ -8,6 +8,7 @@ import pytest
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
from deerflow.tools.builtins.invoke_acp_agent_tool import (
    _build_acp_mcp_servers,
    _build_mcp_servers,
    _build_permission_response,
    _get_work_dir,
@ -42,6 +43,43 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
    set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))


def test_build_acp_mcp_servers_formats_list_payload():
    set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
    fresh_config = ExtensionsConfig(
        mcp_servers={
            "stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
            "http": McpServerConfig(enabled=True, type="http", url="https://example.com/mcp", headers={"Authorization": "Bearer token"}),
            "disabled": McpServerConfig(enabled=False, type="stdio", command="echo"),
        },
        skills={},
    )
    monkeypatch = pytest.MonkeyPatch()
    monkeypatch.setattr(
        "deerflow.config.extensions_config.ExtensionsConfig.from_file",
        classmethod(lambda cls: fresh_config),
    )

    try:
        assert _build_acp_mcp_servers() == [
            {
                "name": "stdio",
                "type": "stdio",
                "command": "npx",
                "args": ["srv"],
                "env": [{"name": "FOO", "value": "bar"}],
            },
            {
                "name": "http",
                "type": "http",
                "url": "https://example.com/mcp",
                "headers": [{"name": "Authorization", "value": "Bearer token"}],
            },
        ]
    finally:
        monkeypatch.undo()
        set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))


def test_build_permission_response_prefers_allow_once():
    response = _build_permission_response(
        [
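
The asserted payload shape in isolation: enabled config entries become a list of named servers, with `env`/`headers` maps flattened into `{"name", "value"}` pairs. A sketch of that mapping, with plain dicts standing in for `McpServerConfig`:

```python
# Sketch of the config -> ACP list-payload mapping the assertion above checks.
def to_acp_entry(name: str, cfg: dict) -> dict:
    if cfg["type"] == "stdio":
        return {
            "name": name,
            "type": "stdio",
            "command": cfg["command"],
            "args": cfg.get("args", []),
            "env": [{"name": k, "value": v} for k, v in cfg.get("env", {}).items()],
        }
    return {  # http transport
        "name": name,
        "type": "http",
        "url": cfg["url"],
        "headers": [{"name": k, "value": v} for k, v in cfg.get("headers", {}).items()],
    }
```

Disabled servers are filtered out before this mapping runs, which is why the `disabled` entry never appears in the expected list.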
@ -251,9 +289,15 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
    assert captured["spawn"] == {"cmd": "codex-acp", "args": ["--json"], "cwd": expected_cwd}
    assert captured["new_session"] == {
        "cwd": expected_cwd,
        "mcp_servers": {
            "github": {"transport": "stdio", "command": "npx", "args": ["github-mcp"]},
        },
        "mcp_servers": [
            {
                "name": "github",
                "type": "stdio",
                "command": "npx",
                "args": ["github-mcp"],
                "env": [],
            }
        ],
        "model": "gpt-5-codex",
    }
    assert captured["prompt"] == {
@ -448,6 +492,94 @@ async def test_invoke_acp_agent_passes_env_to_spawn(monkeypatch, tmp_path):
    assert captured["env"] == {"OPENAI_API_KEY": "sk-from-env", "FOO": "bar"}


@pytest.mark.anyio
async def test_invoke_acp_agent_skips_invalid_mcp_servers(monkeypatch, tmp_path, caplog):
    """Invalid MCP config should be logged and skipped instead of failing ACP invocation."""
    from deerflow.config import paths as paths_module

    monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
    monkeypatch.setattr(
        "deerflow.tools.builtins.invoke_acp_agent_tool._build_acp_mcp_servers",
        lambda: (_ for _ in ()).throw(ValueError("missing command")),
    )

    captured: dict[str, object] = {}

    class DummyClient:
        def __init__(self) -> None:
            self._chunks: list[str] = []

        @property
        def collected_text(self) -> str:
            return ""

        async def session_update(self, session_id, update, **kwargs):
            pass

        async def request_permission(self, options, session_id, tool_call, **kwargs):
            raise AssertionError("should not be called")

    class DummyConn:
        async def initialize(self, **kwargs):
            pass

        async def new_session(self, **kwargs):
            captured["new_session"] = kwargs
            return SimpleNamespace(session_id="s1")

        async def prompt(self, **kwargs):
            pass

    class DummyProcessContext:
        def __init__(self, client, cmd, *args, env=None, cwd=None):
            captured["spawn"] = {"cmd": cmd, "args": list(args), "env": env, "cwd": cwd}

        async def __aenter__(self):
            return DummyConn(), object()

        async def __aexit__(self, exc_type, exc, tb):
            return False

    class DummyRequestError(Exception):
        @staticmethod
        def method_not_found(method):
            return DummyRequestError(method)

    monkeypatch.setitem(
        sys.modules,
        "acp",
        SimpleNamespace(
            PROTOCOL_VERSION="2026-03-24",
            Client=DummyClient,
            RequestError=DummyRequestError,
            spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
            text_block=lambda text: {"type": "text", "text": text},
        ),
    )
    monkeypatch.setitem(
        sys.modules,
        "acp.schema",
        SimpleNamespace(
            ClientCapabilities=lambda: {},
            Implementation=lambda **kwargs: kwargs,
            TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
        ),
    )

    tool = build_invoke_acp_agent_tool({"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")})
    caplog.set_level("WARNING")

    try:
        await tool.coroutine(agent="codex", prompt="Do something")
    finally:
        sys.modules.pop("acp", None)
        sys.modules.pop("acp.schema", None)

    assert captured["new_session"]["mcp_servers"] == []
    assert "continuing without MCP servers" in caplog.text
    assert "missing command" in caplog.text


@pytest.mark.anyio
async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch, tmp_path):
    """When env is empty, None is passed to spawn_agent_process (subprocess inherits parent env)."""

312
backend/tests/test_langgraph_auth.py
Normal file
@ -0,0 +1,312 @@
"""Tests for LangGraph Server auth handler (langgraph_auth.py).

Validates that the LangGraph auth layer enforces the same rules as Gateway:
cookie → JWT decode → DB lookup → token_version check → owner filter
"""

import asyncio
import os
from datetime import timedelta
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from uuid import uuid4

import pytest

os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")

from langgraph_sdk import Auth

from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.jwt import create_access_token, decode_token
from app.gateway.auth.models import User
from app.gateway.langgraph_auth import add_owner_filter, authenticate

# ── Helpers ───────────────────────────────────────────────────────────────

_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"


@pytest.fixture(autouse=True)
def _setup_auth_config():
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
    yield
    set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))


def _req(cookies=None, method="GET", headers=None):
    return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})


def _user(user_id=None, token_version=0):
    return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)


def _mock_provider(user=None):
    p = AsyncMock()
    p.get_user = AsyncMock(return_value=user)
    return p


# ── @auth.authenticate ───────────────────────────────────────────────────


def test_no_cookie_raises_401():
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req()))
    assert exc.value.status_code == 401
    assert "Not authenticated" in str(exc.value.detail)


def test_invalid_jwt_raises_401():
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req({"access_token": "garbage"})))
    assert exc.value.status_code == 401
    assert "Token error" in str(exc.value.detail)


def test_expired_jwt_raises_401():
    token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req({"access_token": token})))
    assert exc.value.status_code == 401


def test_user_not_found_raises_401():
    token = create_access_token("ghost")
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
        with pytest.raises(Auth.exceptions.HTTPException) as exc:
            asyncio.run(authenticate(_req({"access_token": token})))
    assert exc.value.status_code == 401
    assert "User not found" in str(exc.value.detail)


def test_token_version_mismatch_raises_401():
    user = _user(token_version=2)
    token = create_access_token(str(user.id), token_version=1)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        with pytest.raises(Auth.exceptions.HTTPException) as exc:
            asyncio.run(authenticate(_req({"access_token": token})))
    assert exc.value.status_code == 401
    assert "revoked" in str(exc.value.detail).lower()


def test_valid_token_returns_user_id():
    user = _user(token_version=0)
    token = create_access_token(str(user.id), token_version=0)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        result = asyncio.run(authenticate(_req({"access_token": token})))
    assert result == str(user.id)


def test_valid_token_matching_version():
    user = _user(token_version=5)
    token = create_access_token(str(user.id), token_version=5)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        result = asyncio.run(authenticate(_req({"access_token": token})))
    assert result == str(user.id)


# ── @auth.authenticate edge cases ────────────────────────────────────────


def test_provider_exception_propagates():
    """Provider raises → should not be swallowed silently."""
    token = create_access_token("user-1")
    p = AsyncMock()
    p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
        with pytest.raises(RuntimeError, match="DB down"):
            asyncio.run(authenticate(_req({"access_token": token})))


def test_jwt_missing_ver_defaults_to_zero():
    """JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
    import jwt as pyjwt

    uid = str(uuid4())
    raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
    user = _user(user_id=uid, token_version=0)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        result = asyncio.run(authenticate(_req({"access_token": raw})))
    assert result == uid


def test_jwt_missing_ver_rejected_when_user_version_nonzero():
    """JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
    import jwt as pyjwt

    uid = str(uuid4())
    raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
    user = _user(user_id=uid, token_version=1)
    with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
        with pytest.raises(Auth.exceptions.HTTPException) as exc:
            asyncio.run(authenticate(_req({"access_token": raw})))
    assert exc.value.status_code == 401


def test_wrong_secret_raises_401():
    """Token signed with different secret → 401."""
    import jwt as pyjwt

    raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req({"access_token": raw})))
    assert exc.value.status_code == 401


# ── @auth.on (owner filter) ──────────────────────────────────────────────


class _FakeUser:
    """Minimal BaseUser-compatible object without langgraph_api.config dependency."""

    def __init__(self, identity: str):
        self.identity = identity
        self.is_authenticated = True
        self.display_name = identity


def _make_ctx(user_id):
    return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])


def test_filter_injects_user_id():
    value = {}
    asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
    assert value["metadata"]["user_id"] == "user-a"


def test_filter_preserves_existing_metadata():
    value = {"metadata": {"title": "hello"}}
    asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
    assert value["metadata"]["user_id"] == "user-a"
    assert value["metadata"]["title"] == "hello"


def test_filter_returns_user_id_dict():
    result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
    assert result == {"user_id": "user-x"}


def test_filter_read_write_consistency():
    value = {}
    filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
    assert value["metadata"]["user_id"] == filter_dict["user_id"]


def test_different_users_different_filters():
    f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
    f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
    assert f_a["user_id"] != f_b["user_id"]


def test_filter_overrides_conflicting_user_id():
    """If value already has a different user_id in metadata, it gets overwritten."""
    value = {"metadata": {"user_id": "attacker"}}
    asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
    assert value["metadata"]["user_id"] == "real-owner"


def test_filter_with_empty_metadata():
    """Explicit empty metadata dict is fine."""
    value = {"metadata": {}}
    result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
    assert value["metadata"]["user_id"] == "user-z"
    assert result == {"user_id": "user-z"}


# ── Gateway parity ───────────────────────────────────────────────────────


def test_shared_jwt_secret():
    token = create_access_token("user-1", token_version=3)
    payload = decode_token(token)
    from app.gateway.auth.errors import TokenError

    assert not isinstance(payload, TokenError)
    assert payload.sub == "user-1"
    assert payload.ver == 3


def test_langgraph_json_has_auth_path():
    import json

    config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
    assert "auth" in config
    assert "langgraph_auth" in config["auth"]["path"]


def test_auth_handler_has_both_layers():
    from app.gateway.langgraph_auth import auth

    assert auth._authenticate_handler is not None
    assert len(auth._global_handlers) == 1


# ── CSRF in LangGraph auth ──────────────────────────────────────────────


def test_csrf_get_no_check():
    """GET requests skip CSRF — should proceed to JWT validation."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="GET")))
    # Rejected by missing cookie, NOT by CSRF
    assert exc.value.status_code == 401
    assert "Not authenticated" in str(exc.value.detail)


def test_csrf_post_missing_token():
    """POST without CSRF token → 403."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
    assert exc.value.status_code == 403
    assert "CSRF token missing" in str(exc.value.detail)


def test_csrf_post_mismatched_token():
    """POST with mismatched CSRF tokens → 403."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(
            authenticate(
                _req(
                    method="POST",
                    cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
                    headers={"x-csrf-token": "wrong-token"},
                )
            )
        )
    assert exc.value.status_code == 403
    assert "mismatch" in str(exc.value.detail)


def test_csrf_post_matching_token_proceeds_to_jwt():
    """POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(
            authenticate(
                _req(
                    method="POST",
                    cookies={"access_token": "garbage", "csrf_token": "same-token"},
                    headers={"x-csrf-token": "same-token"},
                )
            )
        )
    # Past CSRF, rejected by JWT decode
    assert exc.value.status_code == 401
    assert "Token error" in str(exc.value.detail)


def test_csrf_put_requires_token():
    """PUT also requires CSRF."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
    assert exc.value.status_code == 403


def test_csrf_delete_requires_token():
    """DELETE also requires CSRF."""
    with pytest.raises(Auth.exceptions.HTTPException) as exc:
        asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
    assert exc.value.status_code == 403
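
Taken together, these tests fix the order of checks in `authenticate`: CSRF first (mutating methods only), then cookie presence, JWT decode, DB lookup, and token-version comparison. A sketch of that pipeline — the exception helper and user lookup are injected here so the snippet stands alone; it is not the shipped handler:

```python
# Sketch of the check order the tests above enforce.
async def authenticate_sketch(request, decode_token, get_user, http_error):
    # 1. CSRF double-submit check, mutating methods only (GET skips it).
    if request.method in {"POST", "PUT", "PATCH", "DELETE"}:
        cookie = request.cookies.get("csrf_token")
        header = request.headers.get("x-csrf-token")
        if not cookie or not header:
            raise http_error(403, "CSRF token missing")
        if cookie != header:
            raise http_error(403, "CSRF token mismatch")
    # 2. Session cookie must be present.
    token = request.cookies.get("access_token")
    if not token:
        raise http_error(401, "Not authenticated")
    # 3. JWT decode (signature, expiry) -> 401 "Token error" on failure.
    payload = decode_token(token)
    # 4. DB lookup; a missing 'ver' claim defaults to 0.
    user = await get_user(payload.sub)
    if user is None:
        raise http_error(401, "User not found")
    # 5. Stale-token rejection via the version counter.
    if getattr(payload, "ver", 0) != user.token_version:
        raise http_error(401, "Token revoked")
    return str(user.id)
```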
388
backend/tests/test_local_sandbox_provider_mounts.py
Normal file
@ -0,0 +1,388 @@
import errno
from types import SimpleNamespace
from unittest.mock import patch

import pytest

from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider


class TestPathMapping:
    def test_path_mapping_dataclass(self):
        mapping = PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True)
        assert mapping.container_path == "/mnt/skills"
        assert mapping.local_path == "/home/user/skills"
        assert mapping.read_only is True

    def test_path_mapping_defaults_to_false(self):
        mapping = PathMapping(container_path="/mnt/data", local_path="/home/user/data")
        assert mapping.read_only is False


class TestLocalSandboxPathResolution:
    def test_resolve_path_exact_match(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
            ],
        )
        resolved = sandbox._resolve_path("/mnt/skills")
        assert resolved == "/home/user/skills"

    def test_resolve_path_nested_path(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
            ],
        )
        resolved = sandbox._resolve_path("/mnt/skills/agent/prompt.py")
        assert resolved == "/home/user/skills/agent/prompt.py"

    def test_resolve_path_no_mapping(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
            ],
        )
        resolved = sandbox._resolve_path("/mnt/other/file.txt")
        assert resolved == "/mnt/other/file.txt"

    def test_resolve_path_longest_prefix_first(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
                PathMapping(container_path="/mnt", local_path="/var/mnt"),
            ],
        )
        resolved = sandbox._resolve_path("/mnt/skills/file.py")
        # Should match /mnt/skills first (longer prefix)
        assert resolved == "/home/user/skills/file.py"

    def test_reverse_resolve_path_exact_match(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
            ],
        )
        resolved = sandbox._reverse_resolve_path(str(skills_dir))
        assert resolved == "/mnt/skills"

    def test_reverse_resolve_path_nested(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()
        file_path = skills_dir / "agent" / "prompt.py"
        file_path.parent.mkdir()
        file_path.write_text("test")

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
            ],
        )
        resolved = sandbox._reverse_resolve_path(str(file_path))
        assert resolved == "/mnt/skills/agent/prompt.py"


class TestReadOnlyPath:
    def test_is_read_only_true(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
            ],
        )
        assert sandbox._is_read_only_path("/home/user/skills/file.py") is True

    def test_is_read_only_false_for_writable(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/data", local_path="/home/user/data", read_only=False),
            ],
        )
        assert sandbox._is_read_only_path("/home/user/data/file.txt") is False

    def test_is_read_only_false_for_unmapped_path(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
            ],
        )
        # Path not under any mapping
        assert sandbox._is_read_only_path("/tmp/other/file.txt") is False

    def test_is_read_only_true_for_exact_match(self):
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
            ],
        )
        assert sandbox._is_read_only_path("/home/user/skills") is True

    def test_write_file_blocked_on_read_only(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
            ],
        )
        # Skills dir is read-only, write should be blocked
        with pytest.raises(OSError) as exc_info:
            sandbox.write_file("/mnt/skills/new_file.py", "content")
        assert exc_info.value.errno == errno.EROFS

    def test_write_file_allowed_on_writable_mount(self, tmp_path):
        data_dir = tmp_path / "data"
        data_dir.mkdir()

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
            ],
        )
        sandbox.write_file("/mnt/data/file.txt", "content")
        assert (data_dir / "file.txt").read_text() == "content"

    def test_update_file_blocked_on_read_only(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()
        existing_file = skills_dir / "existing.py"
        existing_file.write_bytes(b"original")

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
            ],
        )
        with pytest.raises(OSError) as exc_info:
            sandbox.update_file("/mnt/skills/existing.py", b"updated")
        assert exc_info.value.errno == errno.EROFS


class TestMultipleMounts:
    def test_multiple_read_write_mounts(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()
        data_dir = tmp_path / "data"
        data_dir.mkdir()
        external_dir = tmp_path / "external"
        external_dir.mkdir()

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
                PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
                PathMapping(container_path="/mnt/external", local_path=str(external_dir), read_only=True),
            ],
        )

        # Skills is read-only
        with pytest.raises(OSError):
            sandbox.write_file("/mnt/skills/file.py", "content")

        # Data is writable
        sandbox.write_file("/mnt/data/file.txt", "data content")
        assert (data_dir / "file.txt").read_text() == "data content"

        # External is read-only
        with pytest.raises(OSError):
            sandbox.write_file("/mnt/external/file.txt", "content")

    def test_nested_mounts_writable_under_readonly(self, tmp_path):
        """A writable mount nested under a read-only mount should allow writes."""
        ro_dir = tmp_path / "ro"
        ro_dir.mkdir()
        rw_dir = ro_dir / "writable"
        rw_dir.mkdir()

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/repo", local_path=str(ro_dir), read_only=True),
                PathMapping(container_path="/mnt/repo/writable", local_path=str(rw_dir), read_only=False),
            ],
        )

        # Parent mount is read-only
        with pytest.raises(OSError):
            sandbox.write_file("/mnt/repo/file.txt", "content")

        # Nested writable mount should allow writes
        sandbox.write_file("/mnt/repo/writable/file.txt", "content")
        assert (rw_dir / "file.txt").read_text() == "content"

    def test_execute_command_path_replacement(self, tmp_path, monkeypatch):
        data_dir = tmp_path / "data"
        data_dir.mkdir()
        test_file = data_dir / "test.txt"
        test_file.write_text("hello")

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/data", local_path=str(data_dir)),
            ],
        )

        # Mock subprocess to capture the resolved command
        captured = {}
        original_run = __import__("subprocess").run

        def mock_run(*args, **kwargs):
            if len(args) > 0:
                captured["command"] = args[0]
            return original_run(*args, **kwargs)

        monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.subprocess.run", mock_run)
        monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.LocalSandbox._get_shell", lambda self: "/bin/sh")

        sandbox.execute_command("cat /mnt/data/test.txt")
        # Verify the command received the resolved local path
        assert str(data_dir) in captured.get("command", "")

    def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path):
        foo_dir = tmp_path / "foo"
        foo_dir.mkdir()
        foobar_dir = tmp_path / "foobar"
        foobar_dir.mkdir()
        target = foobar_dir / "file.txt"
        target.write_text("test")

        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/foo", local_path=str(foo_dir)),
            ],
        )

        resolved = sandbox._reverse_resolve_path(str(target))
        assert resolved == str(target.resolve())

    def test_reverse_resolve_paths_in_output_supports_backslash_separator(self, tmp_path):
        mount_dir = tmp_path / "mount"
        mount_dir.mkdir()
        sandbox = LocalSandbox(
            "test",
            [
                PathMapping(container_path="/mnt/data", local_path=str(mount_dir)),
            ],
        )

        output = f"Copied: {mount_dir}\\file.txt"
        masked = sandbox._reverse_resolve_paths_in_output(output)

        assert "/mnt/data/file.txt" in masked
        assert str(mount_dir) not in masked


class TestLocalSandboxProviderMounts:
    def test_setup_path_mappings_uses_configured_skills_container_path_as_reserved_prefix(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()
        custom_dir = tmp_path / "custom"
        custom_dir.mkdir()

        from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig

        sandbox_config = SandboxConfig(
            use="deerflow.sandbox.local:LocalSandboxProvider",
            mounts=[
                VolumeMountConfig(host_path=str(custom_dir), container_path="/custom-skills/nested", read_only=False),
            ],
        )
        config = SimpleNamespace(
            skills=SimpleNamespace(container_path="/custom-skills", get_skills_path=lambda: skills_dir),
            sandbox=sandbox_config,
        )

        with patch("deerflow.config.get_app_config", return_value=config):
            provider = LocalSandboxProvider()

        assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]

    def test_setup_path_mappings_skips_relative_host_path(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()

        from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig

        sandbox_config = SandboxConfig(
            use="deerflow.sandbox.local:LocalSandboxProvider",
            mounts=[
                VolumeMountConfig(host_path="relative/path", container_path="/mnt/data", read_only=False),
            ],
        )
        config = SimpleNamespace(
            skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
            sandbox=sandbox_config,
        )

        with patch("deerflow.config.get_app_config", return_value=config):
            provider = LocalSandboxProvider()

        assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]

    def test_setup_path_mappings_skips_non_absolute_container_path(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()
        custom_dir = tmp_path / "custom"
        custom_dir.mkdir()

        from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig

        sandbox_config = SandboxConfig(
            use="deerflow.sandbox.local:LocalSandboxProvider",
            mounts=[
                VolumeMountConfig(host_path=str(custom_dir), container_path="mnt/data", read_only=False),
            ],
        )
        config = SimpleNamespace(
            skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
            sandbox=sandbox_config,
        )

        with patch("deerflow.config.get_app_config", return_value=config):
            provider = LocalSandboxProvider()

        assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]

    def test_setup_path_mappings_normalizes_container_path_trailing_slash(self, tmp_path):
        skills_dir = tmp_path / "skills"
        skills_dir.mkdir()
        custom_dir = tmp_path / "custom"
        custom_dir.mkdir()

        from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig

        sandbox_config = SandboxConfig(
            use="deerflow.sandbox.local:LocalSandboxProvider",
            mounts=[
                VolumeMountConfig(host_path=str(custom_dir), container_path="/mnt/data/", read_only=False),
            ],
        )
        config = SimpleNamespace(
            skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
            sandbox=sandbox_config,
        )

        with patch("deerflow.config.get_app_config", return_value=config):
            provider = LocalSandboxProvider()

        assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
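
The resolution rule the path-resolution tests above establish is longest-prefix-wins with pass-through for unmapped paths. In isolation, with tuples standing in for `PathMapping`:

```python
# Sketch of longest-prefix mount resolution; a stand-in for LocalSandbox._resolve_path.
def resolve_path(path: str, mappings: list[tuple[str, str]]) -> str:
    # Longest container prefix first, so /mnt/skills beats /mnt.
    for container, local in sorted(mappings, key=lambda m: len(m[0]), reverse=True):
        # Require an exact match or a '/' boundary so /mnt/foo never claims /mnt/foobar.
        if path == container or path.startswith(container + "/"):
            return local + path[len(container):]
    return path  # unmapped paths pass through untouched


# e.g. resolve_path("/mnt/skills/file.py",
#                   [("/mnt/skills", "/home/user/skills"), ("/mnt", "/var/mnt")])
# -> "/home/user/skills/file.py"
```

The `test_reverse_resolve_path_does_not_match_partial_prefix` case above pins down the same separator-boundary rule for the reverse direction.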
@ -1,5 +1,6 @@
"""Tests for LoopDetectionMiddleware."""

import copy
from unittest.mock import MagicMock

from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@ -19,8 +20,13 @@ def _make_runtime(thread_id="test-thread"):


def _make_state(tool_calls=None, content=""):
    """Build a minimal AgentState dict with an AIMessage."""
    msg = AIMessage(content=content, tool_calls=tool_calls or [])
    """Build a minimal AgentState dict with an AIMessage.

    Deep-copies *content* when it is mutable (e.g. list) so that
    successive calls never share the same object reference.
    """
    safe_content = copy.deepcopy(content) if isinstance(content, list) else content
    msg = AIMessage(content=safe_content, tool_calls=tool_calls or [])
    return {"messages": [msg]}


@ -229,3 +235,114 @@ class TestLoopDetection:

        mw._apply(_make_state(tool_calls=call), runtime)
        assert "default" in mw._history


class TestAppendText:
    """Unit tests for LoopDetectionMiddleware._append_text."""

    def test_none_content_returns_text(self):
        result = LoopDetectionMiddleware._append_text(None, "hello")
        assert result == "hello"

    def test_str_content_concatenates(self):
        result = LoopDetectionMiddleware._append_text("existing", "appended")
        assert result == "existing\n\nappended"

    def test_empty_str_content_concatenates(self):
        result = LoopDetectionMiddleware._append_text("", "appended")
        assert result == "\n\nappended"

    def test_list_content_appends_text_block(self):
        """List content (e.g. Anthropic thinking mode) should get a new text block."""
        content = [
            {"type": "thinking", "text": "Let me think..."},
            {"type": "text", "text": "Here is my answer"},
        ]
        result = LoopDetectionMiddleware._append_text(content, "stop msg")
        assert isinstance(result, list)
        assert len(result) == 3
        assert result[0] == content[0]
        assert result[1] == content[1]
        assert result[2] == {"type": "text", "text": "\n\nstop msg"}

    def test_empty_list_content_appends_text_block(self):
        result = LoopDetectionMiddleware._append_text([], "stop msg")
        assert isinstance(result, list)
        assert len(result) == 1
        assert result[0] == {"type": "text", "text": "\n\nstop msg"}

    def test_unexpected_type_coerced_to_str(self):
        """Unexpected content types should be coerced to str as a fallback."""
        result = LoopDetectionMiddleware._append_text(42, "stop msg")
        assert isinstance(result, str)
        assert result == "42\n\nstop msg"

    def test_list_content_not_mutated_in_place(self):
        """_append_text must not modify the original list."""
        original = [{"type": "text", "text": "hello"}]
        result = LoopDetectionMiddleware._append_text(original, "appended")
        assert len(original) == 1  # original unchanged
        assert len(result) == 2  # new list has the appended block

|
||||
"""Regression tests: hard stop must not crash when AIMessage.content is a list."""
|
||||
|
||||
def test_hard_stop_with_list_content(self):
|
||||
"""Hard stop on list content should not raise TypeError (regression)."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
# Build state with list content (e.g. Anthropic thinking mode)
|
||||
list_content = [
|
||||
{"type": "thinking", "text": "Let me think..."},
|
||||
{"type": "text", "text": "I'll run ls"},
|
||||
]
|
||||
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
|
||||
|
||||
# Fourth call triggers hard stop — must not raise TypeError
|
||||
result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert msg.tool_calls == []
|
||||
# Content should remain a list with the stop message appended
|
||||
assert isinstance(msg.content, list)
|
||||
assert len(msg.content) == 3
|
||||
assert msg.content[2]["type"] == "text"
|
||||
assert _HARD_STOP_MSG in msg.content[2]["text"]
|
||||
|
||||
def test_hard_stop_with_none_content(self):
|
||||
"""Hard stop on None content should produce a plain string."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Fourth call with default empty-string content
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg.content, str)
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
def test_hard_stop_with_str_content(self):
|
||||
"""Hard stop on str content should concatenate the stop message."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
|
||||
|
||||
result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg.content, str)
|
||||
assert msg.content.startswith("thinking...")
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
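The TestAppendText cases fully pin down the dispatch behavior of _append_text. A minimal standalone sketch consistent with them (inferred from the assertions, not the actual middleware source):

from typing import Any

def _append_text(content: Any, text: str) -> Any:
    # Hypothetical sketch: None -> plain string; str (including "") -> concatenate;
    # list -> new list with an extra text block; anything else coerced to str.
    if content is None:
        return text
    if isinstance(content, str):
        return f"{content}\n\n{text}"
    if isinstance(content, list):
        # Return a fresh list so the caller's content is never mutated in place.
        return [*content, {"type": "text", "text": f"\n\n{text}"}]
    return f"{content}\n\n{text}"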
@ -154,3 +154,22 @@ def test_format_memory_renders_correction_without_source_error_normally() -> Non

    assert "Use make dev for local development." in result
    assert "avoid:" not in result


def test_format_memory_includes_long_term_background() -> None:
    """longTermBackground in history must be injected into the prompt."""
    memory_data = {
        "user": {},
        "history": {
            "recentMonths": {"summary": "Recent activity summary"},
            "earlierContext": {"summary": "Earlier context summary"},
            "longTermBackground": {"summary": "Core expertise in distributed systems"},
        },
        "facts": [],
    }

    result = format_memory_for_injection(memory_data, max_tokens=2000)

    assert "Background: Core expertise in distributed systems" in result
    assert "Recent: Recent activity summary" in result
    assert "Earlier: Earlier context summary" in result
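The assertions encode a Background/Recent/Earlier label mapping for the three history tiers. A sketch of the rendering they imply, assuming a hypothetical helper name:

def _render_history_block(history: dict) -> str:
    # Hypothetical sketch of the label mapping the assertions pin down.
    parts = []
    if summary := history.get("longTermBackground", {}).get("summary"):
        parts.append(f"Background: {summary}")
    if summary := history.get("recentMonths", {}).get("summary"):
        parts.append(f"Recent: {summary}")
    if summary := history.get("earlierContext", {}).get("summary"):
        parts.append(f"Earlier: {summary}")
    return "\n".join(parts)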
@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
        thread_id="thread-1",
        agent_name="lead_agent",
        correction_detected=True,
        reinforcement_detected=False,
    )


def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
    queue = MemoryUpdateQueue()

    with (
        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
        patch.object(queue, "_reset_timer"),
    ):
        queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
        queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)

    assert len(queue._queue) == 1
    assert queue._queue[0].messages == ["second"]
    assert queue._queue[0].reinforcement_detected is True


def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
    queue = MemoryUpdateQueue()
    queue._queue = [
        ConversationContext(
            thread_id="thread-1",
            messages=["conversation"],
            agent_name="lead_agent",
            reinforcement_detected=True,
        )
    ]
    mock_updater = MagicMock()
    mock_updater.update_memory.return_value = True

    with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
        queue._process_queue()

    mock_updater.update_memory.assert_called_once_with(
        messages=["conversation"],
        thread_id="thread-1",
        agent_name="lead_agent",
        correction_detected=False,
        reinforcement_detected=True,
    )
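Together these tests imply that re-queuing the same thread replaces the buffered messages but ORs the signal flags. A sketch of that merge, under the assumption that the queue stores ConversationContext entries as shown above (the real MemoryUpdateQueue.add may differ):

def add(self, *, thread_id, messages, correction_detected=False, reinforcement_detected=False):
    # Hypothetical sketch inferred from the tests: newest messages win,
    # but a previously detected signal flag is never cleared.
    for ctx in self._queue:
        if ctx.thread_id == thread_id:
            ctx.messages = messages
            ctx.correction_detected = ctx.correction_detected or correction_detected
            ctx.reinforcement_detected = ctx.reinforcement_detected or reinforcement_detected
            self._reset_timer()
            return
    self._queue.append(
        ConversationContext(
            thread_id=thread_id,
            messages=messages,
            correction_detected=correction_detected,
            reinforcement_detected=reinforcement_detected,
        )
    )
    self._reset_timer()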
@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse:
        assert result is True
        prompt = model.invoke.call_args[0][0]
        assert "Explicit correction signals were detected" not in prompt


class TestFactDeduplicationCaseInsensitive:
    """Tests that fact deduplication is case-insensitive."""

    def test_duplicate_fact_different_case_not_stored(self):
        updater = MemoryUpdater()
        current_memory = _make_memory(
            facts=[
                {
                    "id": "fact_1",
                    "content": "User prefers Python",
                    "category": "preference",
                    "confidence": 0.9,
                    "createdAt": "2026-01-01T00:00:00Z",
                    "source": "thread-a",
                },
            ]
        )
        # Same fact with different casing should be treated as duplicate
        update_data = {
            "factsToRemove": [],
            "newFacts": [
                {"content": "user prefers python", "category": "preference", "confidence": 0.95},
            ],
        }

        with patch(
            "deerflow.agents.memory.updater.get_memory_config",
            return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
        ):
            result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")

        # Should still have only 1 fact (duplicate rejected)
        assert len(result["facts"]) == 1
        assert result["facts"][0]["content"] == "User prefers Python"

    def test_unique_fact_different_case_and_content_stored(self):
        updater = MemoryUpdater()
        current_memory = _make_memory(
            facts=[
                {
                    "id": "fact_1",
                    "content": "User prefers Python",
                    "category": "preference",
                    "confidence": 0.9,
                    "createdAt": "2026-01-01T00:00:00Z",
                    "source": "thread-a",
                },
            ]
        )
        update_data = {
            "factsToRemove": [],
            "newFacts": [
                {"content": "User prefers Go", "category": "preference", "confidence": 0.85},
            ],
        }

        with patch(
            "deerflow.agents.memory.updater.get_memory_config",
            return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
        ):
            result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")

        assert len(result["facts"]) == 2


class TestReinforcementHint:
    """Tests that reinforcement_detected injects the correct hint into the prompt."""

    @staticmethod
    def _make_mock_model(json_response: str):
        model = MagicMock()
        response = MagicMock()
        response.content = f"```json\n{json_response}\n```"
        model.invoke.return_value = response
        return model

    def test_reinforcement_hint_injected_when_detected(self):
        updater = MemoryUpdater()
        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
        model = self._make_mock_model(valid_json)

        with (
            patch.object(updater, "_get_model", return_value=model),
            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
        ):
            msg = MagicMock()
            msg.type = "human"
            msg.content = "Yes, exactly! That's what I needed."
            ai_msg = MagicMock()
            ai_msg.type = "ai"
            ai_msg.content = "Great to hear!"
            ai_msg.tool_calls = []

            result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)

        assert result is True
        prompt = model.invoke.call_args[0][0]
        assert "Positive reinforcement signals were detected" in prompt

    def test_reinforcement_hint_absent_when_not_detected(self):
        updater = MemoryUpdater()
        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
        model = self._make_mock_model(valid_json)

        with (
            patch.object(updater, "_get_model", return_value=model),
            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
        ):
            msg = MagicMock()
            msg.type = "human"
            msg.content = "Tell me more."
            ai_msg = MagicMock()
            ai_msg.type = "ai"
            ai_msg.content = "Sure."
            ai_msg.tool_calls = []

            result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)

        assert result is True
        prompt = model.invoke.call_args[0][0]
        assert "Positive reinforcement signals were detected" not in prompt

    def test_both_hints_present_when_both_detected(self):
        updater = MemoryUpdater()
        valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
        model = self._make_mock_model(valid_json)

        with (
            patch.object(updater, "_get_model", return_value=model),
            patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
            patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
            patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
        ):
            msg = MagicMock()
            msg.type = "human"
            msg.content = "No wait, that's wrong. Actually yes, exactly right."
            ai_msg = MagicMock()
            ai_msg.type = "ai"
            ai_msg.content = "Got it."
            ai_msg.tool_calls = []

            result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)

        assert result is True
        prompt = model.invoke.call_args[0][0]
        assert "Explicit correction signals were detected" in prompt
        assert "Positive reinforcement signals were detected" in prompt
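The dedup tests reduce to a case-insensitive comparison of fact content. A sketch of the check they imply, with a hypothetical helper name (the real logic sits inside _apply_updates):

def _is_duplicate_fact(existing_facts: list[dict], new_content: str) -> bool:
    # Hypothetical sketch: normalize both sides before comparing, so
    # "user prefers python" collides with "User prefers Python".
    normalized = new_content.strip().lower()
    return any(f["content"].strip().lower() == normalized for f in existing_facts)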
@ -10,7 +10,7 @@ persisting in long-term memory:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement

# ---------------------------------------------------------------------------
# Helpers
@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory:
        mem = {"user": {}, "history": {}, "facts": []}
        result = _strip_upload_mentions_from_memory(mem)
        assert result == {"user": {}, "history": {}, "facts": []}


# ===========================================================================
# detect_reinforcement
# ===========================================================================


class TestDetectReinforcement:
    def test_detects_english_reinforcement_signal(self):
        msgs = [
            _human("Can you summarise it in bullet points?"),
            _ai("Here are the key points: ..."),
            _human("Yes, exactly! That's what I needed."),
            _ai("Glad it helped."),
        ]

        assert detect_reinforcement(msgs) is True

    def test_detects_perfect_signal(self):
        msgs = [
            _human("Write it more concisely."),
            _ai("Here is the concise version."),
            _human("Perfect."),
            _ai("Great!"),
        ]

        assert detect_reinforcement(msgs) is True

    def test_detects_chinese_reinforcement_signal(self):
        msgs = [
            _human("帮我用要点来总结"),
            _ai("好的,要点如下:..."),
            _human("完全正确,就是这个意思"),
            _ai("很高兴能帮到你"),
        ]

        assert detect_reinforcement(msgs) is True

    def test_returns_false_without_signal(self):
        msgs = [
            _human("What does this function do?"),
            _ai("It processes the input data."),
            _human("Can you show me an example?"),
        ]

        assert detect_reinforcement(msgs) is False

    def test_only_checks_recent_messages(self):
        # Reinforcement signal buried beyond the -6 window should not trigger
        msgs = [
            _human("Yes, exactly right."),
            _ai("Noted."),
            _human("Let's discuss tests."),
            _ai("Sure."),
            _human("What about linting?"),
            _ai("Use ruff."),
            _human("And formatting?"),
            _ai("Use make format."),
        ]

        assert detect_reinforcement(msgs) is False

    def test_does_not_conflict_with_correction(self):
        # A message can trigger correction but not reinforcement
        msgs = [
            _human("That's wrong, try again."),
            _ai("Corrected."),
        ]

        assert detect_reinforcement(msgs) is False
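These cases constrain detect_reinforcement to a phrase match over only the recent human messages. A sketch consistent with the tests; the phrase list here is illustrative, since the real pattern set is not shown in this diff:

import re

_REINFORCEMENT_PATTERNS = [  # hypothetical, for illustration only
    r"\byes,? exactly\b",
    r"\bperfect\b",
    r"完全正确",
]

def detect_reinforcement(messages) -> bool:
    # Sketch: scan only the last six messages, human turns only.
    for msg in messages[-6:]:
        if getattr(msg, "type", None) != "human":
            continue
        text = str(getattr(msg, "content", "")).lower()
        if any(re.search(p, text) for p in _REINFORCEMENT_PATTERNS):
            return True
    return False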
393 backend/tests/test_sandbox_search_tools.py Normal file
@ -0,0 +1,393 @@
from types import SimpleNamespace
from unittest.mock import patch

from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
from deerflow.sandbox.tools import glob_tool, grep_tool


def _make_runtime(tmp_path):
    workspace = tmp_path / "workspace"
    uploads = tmp_path / "uploads"
    outputs = tmp_path / "outputs"
    workspace.mkdir()
    uploads.mkdir()
    outputs.mkdir()
    return SimpleNamespace(
        state={
            "sandbox": {"sandbox_id": "local"},
            "thread_data": {
                "workspace_path": str(workspace),
                "uploads_path": str(uploads),
                "outputs_path": str(outputs),
            },
        },
        context={"thread_id": "thread-1"},
    )


def test_glob_tool_returns_virtual_paths_and_ignores_common_dirs(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    workspace = tmp_path / "workspace"
    (workspace / "app.py").write_text("print('hi')\n", encoding="utf-8")
    (workspace / "pkg").mkdir()
    (workspace / "pkg" / "util.py").write_text("print('util')\n", encoding="utf-8")
    (workspace / "node_modules").mkdir()
    (workspace / "node_modules" / "skip.py").write_text("ignored\n", encoding="utf-8")

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))

    result = glob_tool.func(
        runtime=runtime,
        description="find python files",
        pattern="**/*.py",
        path="/mnt/user-data/workspace",
    )

    assert "/mnt/user-data/workspace/app.py" in result
    assert "/mnt/user-data/workspace/pkg/util.py" in result
    assert "node_modules" not in result
    assert str(workspace) not in result


def test_glob_tool_supports_skills_virtual_paths(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    skills_dir = tmp_path / "skills"
    (skills_dir / "public" / "demo").mkdir(parents=True)
    (skills_dir / "public" / "demo" / "SKILL.md").write_text("# Demo\n", encoding="utf-8")

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))

    with (
        patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
        patch("deerflow.sandbox.tools._get_skills_host_path", return_value=str(skills_dir)),
    ):
        result = glob_tool.func(
            runtime=runtime,
            description="find skills",
            pattern="**/SKILL.md",
            path="/mnt/skills",
        )

    assert "/mnt/skills/public/demo/SKILL.md" in result
    assert str(skills_dir) not in result


def test_grep_tool_filters_by_glob_and_skips_binary_files(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    workspace = tmp_path / "workspace"
    (workspace / "main.py").write_text("TODO = 'ship it'\nprint(TODO)\n", encoding="utf-8")
    (workspace / "notes.txt").write_text("TODO in txt should be filtered\n", encoding="utf-8")
    (workspace / "image.bin").write_bytes(b"\0binary TODO")

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))

    result = grep_tool.func(
        runtime=runtime,
        description="find todo references",
        pattern="TODO",
        path="/mnt/user-data/workspace",
        glob="**/*.py",
    )

    assert "/mnt/user-data/workspace/main.py:1: TODO = 'ship it'" in result
    assert "notes.txt" not in result
    assert "image.bin" not in result
    assert str(workspace) not in result


def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    workspace = tmp_path / "workspace"
    (workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
    # Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
    monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))

    result = grep_tool.func(
        runtime=runtime,
        description="limit matches",
        pattern="TODO",
        path="/mnt/user-data/workspace",
        max_results=2,
    )

    assert "Found 2 matches under /mnt/user-data/workspace (showing first 2)" in result
    assert "TODO one" in result
    assert "TODO two" in result
    assert "TODO three" not in result
    assert "Results truncated." in result


def test_glob_tool_include_dirs_filters_nested_ignored_paths(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    workspace = tmp_path / "workspace"
    (workspace / "src").mkdir()
    (workspace / "src" / "main.py").write_text("x\n", encoding="utf-8")
    (workspace / "node_modules").mkdir()
    (workspace / "node_modules" / "lib").mkdir()

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))

    result = glob_tool.func(
        runtime=runtime,
        description="find dirs",
        pattern="**",
        path="/mnt/user-data/workspace",
        include_dirs=True,
    )

    assert "src" in result
    assert "node_modules" not in result


def test_grep_tool_literal_mode(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    workspace = tmp_path / "workspace"
    (workspace / "file.py").write_text("price = (a+b)\nresult = a+b\n", encoding="utf-8")

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))

    # literal=True should treat (a+b) as a plain string, not a regex group
    result = grep_tool.func(
        runtime=runtime,
        description="literal search",
        pattern="(a+b)",
        path="/mnt/user-data/workspace",
        literal=True,
    )

    assert "price = (a+b)" in result
    assert "result = a+b" not in result


def test_grep_tool_case_sensitive(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    workspace = tmp_path / "workspace"
    (workspace / "file.py").write_text("TODO: fix\ntodo: also fix\n", encoding="utf-8")

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))

    result = grep_tool.func(
        runtime=runtime,
        description="case sensitive search",
        pattern="TODO",
        path="/mnt/user-data/workspace",
        case_sensitive=True,
    )

    assert "TODO: fix" in result
    assert "todo: also fix" not in result


def test_grep_tool_invalid_regex_returns_error(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))

    result = grep_tool.func(
        runtime=runtime,
        description="bad pattern",
        pattern="[invalid",
        path="/mnt/user-data/workspace",
    )

    assert "Invalid regex pattern" in result


def test_aio_sandbox_glob_include_dirs_filters_nested_ignored(monkeypatch) -> None:
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
    monkeypatch.setattr(
        sandbox._client.file,
        "list_path",
        lambda **kwargs: SimpleNamespace(
            data=SimpleNamespace(
                files=[
                    SimpleNamespace(name="src", path="/mnt/workspace/src"),
                    SimpleNamespace(name="node_modules", path="/mnt/workspace/node_modules"),
                    # child of node_modules — should be filtered via should_ignore_path
                    SimpleNamespace(name="lib", path="/mnt/workspace/node_modules/lib"),
                ]
            )
        ),
    )

    matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)

    assert "/mnt/workspace/src" in matches
    assert "/mnt/workspace/node_modules" not in matches
    assert "/mnt/workspace/node_modules/lib" not in matches
    assert truncated is False


def test_aio_sandbox_grep_invalid_regex_raises() -> None:
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")

    import re

    try:
        sandbox.grep("/mnt/workspace", "[invalid")
        assert False, "Expected re.error"
    except re.error:
        pass


def test_aio_sandbox_glob_parses_json(monkeypatch) -> None:
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
    monkeypatch.setattr(
        sandbox._client.file,
        "find_files",
        lambda **kwargs: SimpleNamespace(data=SimpleNamespace(files=["/mnt/user-data/workspace/app.py", "/mnt/user-data/workspace/node_modules/skip.py"])),
    )

    matches, truncated = sandbox.glob("/mnt/user-data/workspace", "**/*.py")

    assert matches == ["/mnt/user-data/workspace/app.py"]
    assert truncated is False


def test_aio_sandbox_grep_parses_json(monkeypatch) -> None:
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
    monkeypatch.setattr(
        sandbox._client.file,
        "list_path",
        lambda **kwargs: SimpleNamespace(
            data=SimpleNamespace(
                files=[
                    SimpleNamespace(
                        name="app.py",
                        path="/mnt/user-data/workspace/app.py",
                        is_directory=False,
                    )
                ]
            )
        ),
    )
    monkeypatch.setattr(
        sandbox._client.file,
        "search_in_file",
        lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True"])),
    )

    matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")

    assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
    assert truncated is False


def test_find_glob_matches_raises_not_a_directory(tmp_path) -> None:
    file_path = tmp_path / "file.txt"
    file_path.write_text("x\n", encoding="utf-8")

    try:
        find_glob_matches(file_path, "**/*.py")
        assert False, "Expected NotADirectoryError"
    except NotADirectoryError:
        pass


def test_find_grep_matches_raises_not_a_directory(tmp_path) -> None:
    file_path = tmp_path / "file.txt"
    file_path.write_text("TODO\n", encoding="utf-8")

    try:
        find_grep_matches(file_path, "TODO")
        assert False, "Expected NotADirectoryError"
    except NotADirectoryError:
        pass


def test_find_grep_matches_skips_symlink_outside_root(tmp_path) -> None:
    workspace = tmp_path / "workspace"
    workspace.mkdir()
    outside = tmp_path / "outside.txt"
    outside.write_text("TODO outside\n", encoding="utf-8")
    (workspace / "outside-link.txt").symlink_to(outside)

    matches, truncated = find_grep_matches(workspace, "TODO")

    assert matches == []
    assert truncated is False


def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -> None:
    runtime = _make_runtime(tmp_path)
    workspace = tmp_path / "workspace"
    (workspace / "a.py").write_text("print('a')\n", encoding="utf-8")
    (workspace / "b.py").write_text("print('b')\n", encoding="utf-8")
    (workspace / "c.py").write_text("print('c')\n", encoding="utf-8")

    monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
    monkeypatch.setattr(
        "deerflow.sandbox.tools.get_app_config",
        lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
    )

    result = glob_tool.func(
        runtime=runtime,
        description="limit glob matches",
        pattern="**/*.py",
        path="/mnt/user-data/workspace",
        max_results=2,
    )

    assert "Found 2 paths under /mnt/user-data/workspace (showing first 2)" in result
    assert "Results truncated." in result


def test_aio_sandbox_glob_include_dirs_enforces_root_boundary(monkeypatch) -> None:
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
    monkeypatch.setattr(
        sandbox._client.file,
        "list_path",
        lambda **kwargs: SimpleNamespace(
            data=SimpleNamespace(
                files=[
                    SimpleNamespace(name="src", path="/mnt/workspace/src"),
                    SimpleNamespace(name="src2", path="/mnt/workspace2/src2"),
                ]
            )
        ),
    )

    matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)

    assert matches == ["/mnt/workspace/src"]
    assert truncated is False


def test_aio_sandbox_grep_skips_mismatched_line_number_payloads(monkeypatch) -> None:
    with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
        sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
    monkeypatch.setattr(
        sandbox._client.file,
        "list_path",
        lambda **kwargs: SimpleNamespace(
            data=SimpleNamespace(
                files=[
                    SimpleNamespace(
                        name="app.py",
                        path="/mnt/user-data/workspace/app.py",
                        is_directory=False,
                    )
                ]
            )
        ),
    )
    monkeypatch.setattr(
        sandbox._client.file,
        "search_in_file",
        lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True", "extra"])),
    )

    matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")

    assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
    assert truncated is False
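Several assertions above rely on a specific result header and truncation footer. A sketch of the formatter they imply, with a hypothetical helper name:

def _format_grep_results(root: str, matches: list[GrepMatch], truncated: bool) -> str:
    # Hypothetical sketch of the output shape the assertions encode.
    shown = len(matches)
    lines = [f"Found {shown} matches under {root} (showing first {shown})"]
    lines += [f"{m.path}:{m.line_number}: {m.line}" for m in matches]
    if truncated:
        lines.append("Results truncated.")
    return "\n".join(lines)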
@ -8,7 +8,10 @@ import pytest
from deerflow.sandbox.tools import (
    VIRTUAL_PATH_PREFIX,
    _apply_cwd_prefix,
    _get_custom_mount_for_path,
    _get_custom_mounts,
    _is_acp_workspace_path,
    _is_custom_mount_path,
    _is_skills_path,
    _reject_path_traversal,
    _resolve_acp_workspace_path,
@ -39,6 +42,53 @@ def test_replace_virtual_path_maps_virtual_root_and_subpaths() -> None:
    assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data"


def test_replace_virtual_path_preserves_trailing_slash() -> None:
    """Trailing slash must survive virtual-to-actual path replacement.

    Regression: '/mnt/user-data/workspace/' was previously returned without
    the trailing slash, causing string concatenations like
    output_dir + 'file.txt' to produce a missing-separator path.
    """
    result = replace_virtual_path("/mnt/user-data/workspace/", _THREAD_DATA)
    assert result.endswith("/"), f"Expected trailing slash, got: {result!r}"
    assert result == "/tmp/deer-flow/threads/t1/user-data/workspace/"


def test_replace_virtual_path_preserves_trailing_slash_windows_style() -> None:
    """Trailing slash must be preserved as backslash when actual_base is Windows-style.

    If actual_base uses backslash separators, appending '/' would produce a
    mixed-separator path. The separator must match the style of actual_base.
    """
    win_thread_data = {
        "workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
        "uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
        "outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
    }
    result = replace_virtual_path("/mnt/user-data/workspace/", win_thread_data)
    assert result.endswith("\\"), f"Expected trailing backslash for Windows path, got: {result!r}"
    assert "/" not in result, f"Mixed separators in Windows path: {result!r}"


def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing_slash() -> None:
    """Nested Windows-style subdirectories must keep backslashes throughout."""
    win_thread_data = {
        "workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
        "uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
        "outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
    }
    result = replace_virtual_path("/mnt/user-data/workspace/subdir/", win_thread_data)
    assert result == "C:\\deer-flow\\threads\\t1\\user-data\\workspace\\subdir\\"
    assert "/" not in result, f"Mixed separators in Windows path: {result!r}"


def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
    """Trailing slash on a virtual path inside a command must be preserved."""
    cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
    result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
    assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"


# ---------- mask_local_paths_in_output ----------


@ -96,6 +146,25 @@ def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
        validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)


def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
    with pytest.raises(PermissionError, match="configured mount paths"):
        validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)


def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
    from deerflow.config.sandbox_config import VolumeMountConfig

    mounts = [
        VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
    ]
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
        validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)

    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
        with pytest.raises(PermissionError, match="path traversal"):
            validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)


def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
    """The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
    with pytest.raises(PermissionError, match="Only paths under"):
@ -235,6 +304,22 @@ def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
        validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)


def test_validate_local_bash_command_paths_allows_https_urls() -> None:
    """URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
    validate_local_bash_command_paths(
        "cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
        _THREAD_DATA,
    )


def test_validate_local_bash_command_paths_allows_http_urls() -> None:
    """HTTP URLs must not be flagged as unsafe absolute paths."""
    validate_local_bash_command_paths(
        "curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
        _THREAD_DATA,
    )


def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
    validate_local_bash_command_paths(
        "/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
@ -567,6 +652,156 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
    validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)


# ---------- Custom mount path tests ----------


def _mock_custom_mounts():
    """Create mock VolumeMountConfig objects for testing."""
    from deerflow.config.sandbox_config import VolumeMountConfig

    return [
        VolumeMountConfig(host_path="/home/user/code-read", container_path="/mnt/code-read", read_only=True),
        VolumeMountConfig(host_path="/home/user/data", container_path="/mnt/data", read_only=False),
    ]


def test_is_custom_mount_path_recognises_configured_mounts() -> None:
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        assert _is_custom_mount_path("/mnt/code-read") is True
        assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
        assert _is_custom_mount_path("/mnt/data") is True
        assert _is_custom_mount_path("/mnt/data/file.txt") is True
        assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
        assert _is_custom_mount_path("/mnt/other") is False


def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
    from deerflow.config.sandbox_config import VolumeMountConfig

    mounts = [
        VolumeMountConfig(host_path="/var/mnt", container_path="/mnt", read_only=False),
        VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
    ]
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
        mount = _get_custom_mount_for_path("/mnt/code/file.py")
        assert mount is not None
        assert mount.container_path == "/mnt/code"


def test_validate_local_tool_path_allows_custom_mount_read() -> None:
    """read_file / ls should be able to access custom mount paths."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
        validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)


def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
    """write_file / str_replace must NOT write to read-only custom mounts."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
            validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)


def test_validate_local_tool_path_allows_writable_mount_write() -> None:
    """write_file / str_replace should succeed on writable custom mounts."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)


def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
    """Path traversal via .. in custom mount paths must be rejected."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        with pytest.raises(PermissionError, match="path traversal"):
            validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)


def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
    """bash commands referencing custom mount paths should be allowed."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
        validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)


def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
    """Bash commands with traversal in custom mount paths should be blocked."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        with pytest.raises(PermissionError, match="path traversal"):
            validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)


def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
    """Paths not matching any custom mount should still be blocked."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        with pytest.raises(PermissionError, match="Unsafe absolute paths"):
            validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)


def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
    """_get_custom_mounts should cache after first successful load."""
    # Clear any existing cache
    if hasattr(_get_custom_mounts, "_cached"):
        monkeypatch.delattr(_get_custom_mounts, "_cached")

    # Use real directories so host_path.exists() filtering passes
    dir_a = tmp_path / "code-read"
    dir_a.mkdir()
    dir_b = tmp_path / "data"
    dir_b.mkdir()

    from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig

    mounts = [
        VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
        VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
    ]
    mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
    mock_config = SimpleNamespace(sandbox=mock_sandbox)

    with patch("deerflow.config.get_app_config", return_value=mock_config):
        result = _get_custom_mounts()
        assert len(result) == 2

    # After caching, should return cached value even without mock
    assert hasattr(_get_custom_mounts, "_cached")
    assert len(_get_custom_mounts()) == 2

    # Cleanup
    monkeypatch.delattr(_get_custom_mounts, "_cached")


def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
    """_get_custom_mounts should only return mounts whose host_path exists."""
    if hasattr(_get_custom_mounts, "_cached"):
        monkeypatch.delattr(_get_custom_mounts, "_cached")

    from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig

    existing_dir = tmp_path / "existing"
    existing_dir.mkdir()

    mounts = [
        VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
        VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
    ]
    mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
    mock_config = SimpleNamespace(sandbox=mock_sandbox)

    with patch("deerflow.config.get_app_config", return_value=mock_config):
        result = _get_custom_mounts()
        assert len(result) == 1
        assert result[0].container_path == "/mnt/existing"

    # Cleanup
    monkeypatch.delattr(_get_custom_mounts, "_cached")


def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
    """_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
    with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
        mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
        assert mount is None
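The boundary and longest-prefix tests pin down segment-wise matching: /mnt/code must claim /mnt/code/file.py but never /mnt/code-read-extra/foo, and the most specific mount wins. A sketch consistent with them (the real helper lives in deerflow.sandbox.tools):

from pathlib import PurePosixPath

def _get_custom_mount_for_path(path: str):
    # Hypothetical sketch: segment-wise prefix match, longest container_path wins.
    p = PurePosixPath(path)
    candidates = []
    for mount in _get_custom_mounts():
        root = PurePosixPath(mount.container_path)
        if p == root or root in p.parents:
            candidates.append(mount)
    if not candidates:
        return None
    return max(candidates, key=lambda m: len(PurePosixPath(m.container_path).parts))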
def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) -> None:
    class SharedSandbox:
        def __init__(self) -> None:
@ -140,6 +140,193 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
    assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"


# ---------------------------------------------------------------------------
# END sentinel guarantee tests
# ---------------------------------------------------------------------------


@pytest.mark.anyio
async def test_end_sentinel_delivered_when_queue_full():
    """END sentinel must always be delivered, even when the queue is completely full.

    This is the critical regression test for the bug where publish_end()
    would silently drop the END sentinel when the queue was full, causing
    subscribe() to hang forever and leaking resources.
    """
    bridge = MemoryStreamBridge(queue_maxsize=2)
    run_id = "run-end-full"

    # Fill the queue to capacity
    await bridge.publish(run_id, "event-1", {"n": 1})
    await bridge.publish(run_id, "event-2", {"n": 2})
    assert bridge._queues[run_id].full()

    # publish_end should succeed by evicting old events
    await bridge.publish_end(run_id)

    # Subscriber must receive END_SENTINEL
    events = []
    async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
        events.append(entry)
        if entry is END_SENTINEL:
            break

    assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"


@pytest.mark.anyio
async def test_end_sentinel_evicts_oldest_events():
    """When queue is full, publish_end evicts the oldest events to make room."""
    bridge = MemoryStreamBridge(queue_maxsize=1)
    run_id = "run-evict"

    # Fill queue with one event
    await bridge.publish(run_id, "will-be-evicted", {})
    assert bridge._queues[run_id].full()

    # publish_end must succeed
    await bridge.publish_end(run_id)

    # The only event we should get is END_SENTINEL (the regular event was evicted)
    events = []
    async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
        events.append(entry)
        if entry is END_SENTINEL:
            break

    assert len(events) == 1
    assert events[0] is END_SENTINEL


@pytest.mark.anyio
async def test_end_sentinel_no_eviction_when_space_available():
    """When queue has space, publish_end should not evict anything."""
    bridge = MemoryStreamBridge(queue_maxsize=10)
    run_id = "run-no-evict"

    await bridge.publish(run_id, "event-1", {"n": 1})
    await bridge.publish(run_id, "event-2", {"n": 2})
    await bridge.publish_end(run_id)

    events = []
    async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
        events.append(entry)
        if entry is END_SENTINEL:
            break

    # All events plus END should be present
    assert len(events) == 3
    assert events[0].event == "event-1"
    assert events[1].event == "event-2"
    assert events[2] is END_SENTINEL


@pytest.mark.anyio
async def test_concurrent_tasks_end_sentinel():
    """Multiple concurrent producer/consumer pairs should all terminate properly.

    Simulates the production scenario where multiple runs share a single
    bridge instance — each must receive its own END sentinel.
    """
    bridge = MemoryStreamBridge(queue_maxsize=4)
    num_runs = 4

    async def producer(run_id: str):
        for i in range(10):  # More events than queue capacity
            await bridge.publish(run_id, f"event-{i}", {"i": i})
        await bridge.publish_end(run_id)

    async def consumer(run_id: str) -> list:
        events = []
        async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
            events.append(entry)
            if entry is END_SENTINEL:
                return events
        return events  # pragma: no cover

    # Run producers and consumers concurrently
    run_ids = [f"concurrent-{i}" for i in range(num_runs)]
    producers = [producer(rid) for rid in run_ids]
    consumers = [consumer(rid) for rid in run_ids]

    # Start consumers first, then producers
    consumer_tasks = [asyncio.create_task(c) for c in consumers]
    await asyncio.gather(*producers)

    results = await asyncio.wait_for(
        asyncio.gather(*consumer_tasks),
        timeout=10.0,
    )

    for i, events in enumerate(results):
        assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"


# ---------------------------------------------------------------------------
# Drop counter tests
# ---------------------------------------------------------------------------


@pytest.mark.anyio
async def test_dropped_count_tracking():
    """Dropped events should be tracked per run_id."""
    bridge = MemoryStreamBridge(queue_maxsize=1)
    run_id = "run-drop-count"

    # Fill the queue
    await bridge.publish(run_id, "first", {})

    # This publish will time out and be dropped (we patch timeout to be instant)
    # Instead, we verify the counter after publish_end eviction
    await bridge.publish_end(run_id)

    # dropped_count tracks publish() drops, not publish_end evictions
    assert bridge.dropped_count(run_id) == 0

    # cleanup should also clear the counter
    await bridge.cleanup(run_id)
    assert bridge.dropped_count(run_id) == 0


@pytest.mark.anyio
async def test_dropped_total():
    """dropped_total should sum across all runs."""
    bridge = MemoryStreamBridge(queue_maxsize=256)

    # No drops yet
    assert bridge.dropped_total == 0

    # Manually set some counts to verify the property
    bridge._dropped_counts["run-a"] = 3
    bridge._dropped_counts["run-b"] = 7
    assert bridge.dropped_total == 10


@pytest.mark.anyio
async def test_cleanup_clears_dropped_counts():
    """cleanup() should clear the dropped counter for the run."""
    bridge = MemoryStreamBridge(queue_maxsize=256)
    run_id = "run-cleanup-drops"

    bridge._get_or_create_queue(run_id)
    bridge._dropped_counts[run_id] = 5

    await bridge.cleanup(run_id)
    assert run_id not in bridge._dropped_counts


@pytest.mark.anyio
async def test_close_clears_dropped_counts():
    """close() should clear all dropped counters."""
    bridge = MemoryStreamBridge(queue_maxsize=256)
    bridge._dropped_counts["run-x"] = 10
    bridge._dropped_counts["run-y"] = 20

    await bridge.close()
    assert bridge.dropped_total == 0
    assert len(bridge._dropped_counts) == 0


# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------
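The END-sentinel tests imply an evict-then-retry loop inside publish_end. A sketch of the method body under that assumption; the real MemoryStreamBridge internals may differ:

import asyncio

async def publish_end(self, run_id: str) -> None:
    # Hypothetical sketch: the END sentinel must always land, so when the
    # queue is full we evict the oldest buffered event and retry the put.
    queue = self._get_or_create_queue(run_id)
    while True:
        try:
            queue.put_nowait(END_SENTINEL)
            return
        except asyncio.QueueFull:
            try:
                queue.get_nowait()  # drop oldest event to make room
            except asyncio.QueueEmpty:
                pass  # raced with a consumer; retry the put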
@ -1,8 +1,8 @@
"""Tests for subagent timeout configuration.
"""Tests for subagent runtime configuration.

Covers:
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
- get_timeout_for() resolution logic (global vs per-agent)
- get_timeout_for() / get_max_turns_for() resolution logic
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
- registry.get_subagent_config() applies config overrides
- registry.list_subagents() applies overrides for all agents
@ -24,9 +24,20 @@ from deerflow.subagents.config import SubagentConfig
# ---------------------------------------------------------------------------


def _reset_subagents_config(timeout_seconds: int = 900, agents: dict | None = None) -> None:
def _reset_subagents_config(
    timeout_seconds: int = 900,
    *,
    max_turns: int | None = None,
    agents: dict | None = None,
) -> None:
    """Reset global subagents config to a known state."""
    load_subagents_config_from_dict({"timeout_seconds": timeout_seconds, "agents": agents or {}})
    load_subagents_config_from_dict(
        {
            "timeout_seconds": timeout_seconds,
            "max_turns": max_turns,
            "agents": agents or {},
        }
    )


# ---------------------------------------------------------------------------
@ -38,22 +49,29 @@ class TestSubagentOverrideConfig:
    def test_default_is_none(self):
        override = SubagentOverrideConfig()
        assert override.timeout_seconds is None
        assert override.max_turns is None

    def test_explicit_value(self):
        override = SubagentOverrideConfig(timeout_seconds=300)
        override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42)
        assert override.timeout_seconds == 300
        assert override.max_turns == 42

    def test_rejects_zero(self):
        with pytest.raises(ValueError):
            SubagentOverrideConfig(timeout_seconds=0)
        with pytest.raises(ValueError):
            SubagentOverrideConfig(max_turns=0)

    def test_rejects_negative(self):
        with pytest.raises(ValueError):
            SubagentOverrideConfig(timeout_seconds=-1)
        with pytest.raises(ValueError):
            SubagentOverrideConfig(max_turns=-1)

    def test_minimum_valid_value(self):
        override = SubagentOverrideConfig(timeout_seconds=1)
        override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
        assert override.timeout_seconds == 1
        assert override.max_turns == 1


# ---------------------------------------------------------------------------
@ -66,66 +84,86 @@ class TestSubagentsAppConfigDefaults:
        config = SubagentsAppConfig()
        assert config.timeout_seconds == 900

    def test_default_max_turns_override_is_none(self):
        config = SubagentsAppConfig()
        assert config.max_turns is None

    def test_default_agents_empty(self):
        config = SubagentsAppConfig()
        assert config.agents == {}

    def test_custom_global_timeout(self):
        config = SubagentsAppConfig(timeout_seconds=1800)
    def test_custom_global_runtime_overrides(self):
        config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
        assert config.timeout_seconds == 1800
        assert config.max_turns == 120

    def test_rejects_zero_timeout(self):
        with pytest.raises(ValueError):
            SubagentsAppConfig(timeout_seconds=0)
        with pytest.raises(ValueError):
            SubagentsAppConfig(max_turns=0)

    def test_rejects_negative_timeout(self):
        with pytest.raises(ValueError):
            SubagentsAppConfig(timeout_seconds=-60)
        with pytest.raises(ValueError):
            SubagentsAppConfig(max_turns=-60)


# ---------------------------------------------------------------------------
# SubagentsAppConfig.get_timeout_for()
# SubagentsAppConfig resolution helpers
# ---------------------------------------------------------------------------


class TestGetTimeoutFor:
class TestRuntimeResolution:
    def test_returns_global_default_when_no_override(self):
        config = SubagentsAppConfig(timeout_seconds=600)
        assert config.get_timeout_for("general-purpose") == 600
        assert config.get_timeout_for("bash") == 600
        assert config.get_timeout_for("unknown-agent") == 600
        assert config.get_max_turns_for("general-purpose", 100) == 100
        assert config.get_max_turns_for("bash", 60) == 60

    def test_returns_per_agent_override_when_set(self):
        config = SubagentsAppConfig(
            timeout_seconds=900,
            agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
            max_turns=120,
            agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
        )
        assert config.get_timeout_for("bash") == 300
        assert config.get_max_turns_for("bash", 60) == 80

    def test_other_agents_still_use_global_default(self):
        config = SubagentsAppConfig(
            timeout_seconds=900,
            agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
            max_turns=140,
            agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
        )
        assert config.get_timeout_for("general-purpose") == 900
        assert config.get_max_turns_for("general-purpose", 100) == 140

    def test_agent_with_none_override_falls_back_to_global(self):
        config = SubagentsAppConfig(
            timeout_seconds=900,
            agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None)},
            max_turns=150,
            agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
        )
        assert config.get_timeout_for("general-purpose") == 900
        assert config.get_max_turns_for("general-purpose", 100) == 150

    def test_multiple_per_agent_overrides(self):
        config = SubagentsAppConfig(
            timeout_seconds=900,
            max_turns=120,
            agents={
                "general-purpose": SubagentOverrideConfig(timeout_seconds=1800),
                "bash": SubagentOverrideConfig(timeout_seconds=120),
                "general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
                "bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
            },
        )
        assert config.get_timeout_for("general-purpose") == 1800
        assert config.get_timeout_for("bash") == 120
        assert config.get_max_turns_for("general-purpose", 100) == 200
        assert config.get_max_turns_for("bash", 60) == 80

# ---------------------------------------------------------------------------
|
||||
@@ -139,54 +177,63 @@ class TestLoadSubagentsConfig:
         _reset_subagents_config()

     def test_load_global_timeout(self):
-        load_subagents_config_from_dict({"timeout_seconds": 300})
+        load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
         assert get_subagents_app_config().timeout_seconds == 300
+        assert get_subagents_app_config().max_turns == 120

     def test_load_with_per_agent_overrides(self):
         load_subagents_config_from_dict(
             {
                 "timeout_seconds": 900,
+                "max_turns": 120,
                 "agents": {
-                    "general-purpose": {"timeout_seconds": 1800},
-                    "bash": {"timeout_seconds": 60},
+                    "general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
+                    "bash": {"timeout_seconds": 60, "max_turns": 80},
                 },
             }
         )
         cfg = get_subagents_app_config()
         assert cfg.get_timeout_for("general-purpose") == 1800
         assert cfg.get_timeout_for("bash") == 60
+        assert cfg.get_max_turns_for("general-purpose", 100) == 200
+        assert cfg.get_max_turns_for("bash", 60) == 80

     def test_load_partial_override(self):
         load_subagents_config_from_dict(
             {
                 "timeout_seconds": 600,
-                "agents": {"bash": {"timeout_seconds": 120}},
+                "agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
             }
         )
         cfg = get_subagents_app_config()
         assert cfg.get_timeout_for("general-purpose") == 600
         assert cfg.get_timeout_for("bash") == 120
+        assert cfg.get_max_turns_for("general-purpose", 100) == 100
+        assert cfg.get_max_turns_for("bash", 60) == 70

     def test_load_empty_dict_uses_defaults(self):
         load_subagents_config_from_dict({})
         cfg = get_subagents_app_config()
         assert cfg.timeout_seconds == 900
+        assert cfg.max_turns is None
         assert cfg.agents == {}

     def test_load_replaces_previous_config(self):
-        load_subagents_config_from_dict({"timeout_seconds": 100})
+        load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
         assert get_subagents_app_config().timeout_seconds == 100
+        assert get_subagents_app_config().max_turns == 90

-        load_subagents_config_from_dict({"timeout_seconds": 200})
+        load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
         assert get_subagents_app_config().timeout_seconds == 200
+        assert get_subagents_app_config().max_turns == 110

     def test_singleton_returns_same_instance_between_calls(self):
-        load_subagents_config_from_dict({"timeout_seconds": 777})
+        load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
         assert get_subagents_app_config() is get_subagents_app_config()


 # ---------------------------------------------------------------------------
-# registry.get_subagent_config – timeout override applied
+# registry.get_subagent_config – runtime overrides applied
 # ---------------------------------------------------------------------------

@@ -211,25 +258,29 @@ class TestRegistryGetSubagentConfig:
         _reset_subagents_config(timeout_seconds=900)
         config = get_subagent_config("general-purpose")
         assert config.timeout_seconds == 900
+        assert config.max_turns == 100

     def test_global_timeout_override_applied(self):
         from deerflow.subagents.registry import get_subagent_config

-        _reset_subagents_config(timeout_seconds=1800)
+        _reset_subagents_config(timeout_seconds=1800, max_turns=140)
         config = get_subagent_config("general-purpose")
         assert config.timeout_seconds == 1800
+        assert config.max_turns == 140

-    def test_per_agent_timeout_override_applied(self):
+    def test_per_agent_runtime_override_applied(self):
         from deerflow.subagents.registry import get_subagent_config

         load_subagents_config_from_dict(
             {
                 "timeout_seconds": 900,
-                "agents": {"bash": {"timeout_seconds": 120}},
+                "max_turns": 120,
+                "agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
             }
         )
         bash_config = get_subagent_config("bash")
         assert bash_config.timeout_seconds == 120
+        assert bash_config.max_turns == 80

     def test_per_agent_override_does_not_affect_other_agents(self):
         from deerflow.subagents.registry import get_subagent_config
@@ -237,11 +288,13 @@ class TestRegistryGetSubagentConfig:
         load_subagents_config_from_dict(
             {
                 "timeout_seconds": 900,
-                "agents": {"bash": {"timeout_seconds": 120}},
+                "max_turns": 120,
+                "agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
             }
         )
         gp_config = get_subagent_config("general-purpose")
         assert gp_config.timeout_seconds == 900
+        assert gp_config.max_turns == 120

     def test_builtin_config_object_is_not_mutated(self):
         """Registry must return a new object, leaving the builtin default intact."""
@@ -249,24 +302,27 @@ class TestRegistryGetSubagentConfig:
         from deerflow.subagents.registry import get_subagent_config

         original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
-        load_subagents_config_from_dict({"timeout_seconds": 42})
+        original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
+        load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})

         returned = get_subagent_config("bash")
         assert returned.timeout_seconds == 42
+        assert returned.max_turns == 88
         assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
+        assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns

     def test_config_preserves_other_fields(self):
-        """Applying timeout override must not change other SubagentConfig fields."""
+        """Applying runtime overrides must not change other SubagentConfig fields."""
         from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
         from deerflow.subagents.registry import get_subagent_config

-        _reset_subagents_config(timeout_seconds=300)
+        _reset_subagents_config(timeout_seconds=300, max_turns=140)
         original = BUILTIN_SUBAGENTS["general-purpose"]
         overridden = get_subagent_config("general-purpose")

         assert overridden.name == original.name
         assert overridden.description == original.description
-        assert overridden.max_turns == original.max_turns
+        assert overridden.max_turns == 140
         assert overridden.model == original.model
         assert overridden.tools == original.tools
         assert overridden.disallowed_tools == original.disallowed_tools
@@ -291,9 +347,10 @@ class TestRegistryListSubagents:
     def test_all_returned_configs_get_global_override(self):
         from deerflow.subagents.registry import list_subagents

-        _reset_subagents_config(timeout_seconds=123)
+        _reset_subagents_config(timeout_seconds=123, max_turns=77)
        for cfg in list_subagents():
             assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
+            assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"

     def test_per_agent_overrides_reflected_in_list(self):
         from deerflow.subagents.registry import list_subagents
@@ -301,15 +358,18 @@ class TestRegistryListSubagents:
         load_subagents_config_from_dict(
             {
                 "timeout_seconds": 900,
+                "max_turns": 120,
                 "agents": {
-                    "general-purpose": {"timeout_seconds": 1800},
-                    "bash": {"timeout_seconds": 60},
+                    "general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
+                    "bash": {"timeout_seconds": 60, "max_turns": 80},
                 },
             }
         )
         by_name = {cfg.name: cfg for cfg in list_subagents()}
         assert by_name["general-purpose"].timeout_seconds == 1800
         assert by_name["bash"].timeout_seconds == 60
+        assert by_name["general-purpose"].max_turns == 200
+        assert by_name["bash"].max_turns == 80


 # ---------------------------------------------------------------------------
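The assertions above pin down the resolution order: a per-agent override wins, then the global value, and for max_turns a builtin per-agent default is the final fallback (get_max_turns_for takes that default as its second argument). The implementation itself is not part of this diff; a minimal sketch consistent with these tests, assuming pydantic models (the class, field, and method names come from the tests, everything else is an assumption):

from pydantic import BaseModel, Field


class SubagentOverrideConfig(BaseModel):
    # Per-agent overrides; None means "inherit from the global config".
    timeout_seconds: int | None = Field(default=None, gt=0)
    max_turns: int | None = Field(default=None, gt=0)


class SubagentsAppConfig(BaseModel):
    timeout_seconds: int = Field(default=900, gt=0)
    max_turns: int | None = Field(default=None, gt=0)
    agents: dict[str, SubagentOverrideConfig] = Field(default_factory=dict)

    def get_timeout_for(self, agent_name: str) -> int:
        # Per-agent override, else the global default.
        override = self.agents.get(agent_name)
        if override is not None and override.timeout_seconds is not None:
            return override.timeout_seconds
        return self.timeout_seconds

    def get_max_turns_for(self, agent_name: str, builtin_default: int) -> int:
        # Per-agent override, else global override, else the builtin default.
        override = self.agents.get(agent_name)
        if override is not None and override.max_turns is not None:
            return override.max_turns
        if self.max_turns is not None:
            return self.max_turns
        return builtin_default

The gt=0 constraints would reproduce the rejects-zero/rejects-negative tests, since pydantic's ValidationError is a ValueError subclass and therefore satisfies pytest.raises(ValueError).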
@@ -1,5 +1,5 @@
 import asyncio
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock

 from app.gateway.routers import suggestions

@@ -43,7 +43,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
         model_name=None,
     )
     fake_model = MagicMock()
-    fake_model.invoke.return_value = MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')
+    fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```'))
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)

     result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -61,7 +61,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
         model_name=None,
     )
     fake_model = MagicMock()
-    fake_model.invoke.return_value = MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])
+    fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}]))
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)

     result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -79,7 +79,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
         model_name=None,
     )
     fake_model = MagicMock()
-    fake_model.invoke.return_value = MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])
+    fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}]))
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)

     result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -94,7 +94,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
         model_name=None,
     )
     fake_model = MagicMock()
-    fake_model.invoke.side_effect = RuntimeError("boom")
+    fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom"))
     monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)

     result = asyncio.run(suggestions.generate_suggestions("t1", req))
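Every hunk in this file swaps a synchronous fake_model.invoke stub for an AsyncMock-bound ainvoke, which suggests generate_suggestions now awaits the chat model instead of calling it synchronously. The explicit AsyncMock assignment matters: an attribute auto-created by MagicMock returns a plain MagicMock, which is not awaitable. A self-contained illustration of the pattern (generate here is a stand-in, not the router's actual code):

import asyncio
from unittest.mock import AsyncMock, MagicMock


async def generate(model) -> str:
    # Stand-in for a router function that awaits the chat model.
    response = await model.ainvoke("prompt")
    return response.content


fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content="ok"))

assert asyncio.run(generate(fake_model)) == "ok"
# AsyncMock also records awaited calls, so await-based assertions work:
fake_model.ainvoke.assert_awaited_once_with("prompt")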
111 backend/tests/test_thread_runs_router.py Normal file
@@ -0,0 +1,111 @@
+"""Tests for thread_runs router with auth decorators.
+
+These tests verify that auth decorators properly enforce permission checks
+on run endpoints. They follow the same pattern as test_threads_router.py.
+"""
+
+from unittest.mock import MagicMock, patch
+from uuid import uuid4
+
+from fastapi import FastAPI
+from fastapi.testclient import TestClient
+
+from app.gateway.auth.models import User
+from app.gateway.authz import AuthContext
+from app.gateway.routers.thread_runs import router
+
+
+def test_create_run_requires_auth():
+    """POST /{thread_id}/runs requires auth."""
+    app = FastAPI()
+    app.include_router(router)
+
+    with TestClient(app, raise_server_exceptions=False) as client:
+        response = client.post(
+            "/api/threads/test-thread/runs",
+            json={"assistant_id": "test"},
+        )
+        assert response.status_code == 401
+
+
+def test_create_run_with_auth():
+    """POST /{thread_id}/runs with valid auth passes through."""
+    app = FastAPI()
+    app.include_router(router)
+
+    mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
+    mock_auth = AuthContext(
+        user=mock_user,
+        permissions=["runs:create", "threads:read", "threads:write"],
+    )
+
+    # Mock the checkpointer and run_manager to avoid 503s
+    mock_checkpointer = MagicMock()
+    mock_run_manager = MagicMock()
+    mock_run_manager.list_by_thread = MagicMock(return_value=[])
+    mock_stream_bridge = MagicMock()
+
+    with patch("app.gateway.routers.thread_runs.get_checkpointer", return_value=mock_checkpointer):
+        with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
+            with patch("app.gateway.routers.thread_runs.get_stream_bridge", return_value=mock_stream_bridge):
+                with patch("app.gateway.authz._authenticate", return_value=mock_auth):
+                    with TestClient(app, raise_server_exceptions=False) as client:
+                        # Without a real checkpointer.setup, this will 500 - but the point is auth passed
+                        response = client.post(
+                            "/api/threads/test-thread/runs",
+                            json={"assistant_id": "test"},
+                        )
+                        # Auth passed if we don't get 401
+                        assert response.status_code != 401
+
+
+def test_list_runs_requires_auth():
+    """GET /{thread_id}/runs requires auth."""
+    app = FastAPI()
+    app.include_router(router)
+
+    with TestClient(app, raise_server_exceptions=False) as client:
+        response = client.get("/api/threads/test-thread/runs")
+        assert response.status_code == 401
+
+
+def test_list_runs_with_auth():
+    """GET /{thread_id}/runs with auth passes through."""
+    app = FastAPI()
+    app.include_router(router)
+
+    mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
+    mock_auth = AuthContext(
+        user=mock_user,
+        permissions=["runs:read", "threads:read"],
+    )
+
+    mock_run_manager = MagicMock()
+    mock_run_manager.list_by_thread = MagicMock(return_value=[])
+
+    with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
+        with patch("app.gateway.authz._authenticate", return_value=mock_auth):
+            with TestClient(app, raise_server_exceptions=False) as client:
+                response = client.get("/api/threads/test-thread/runs")
+                # Should not be 401 (may be 500 or other, but auth passed)
+                assert response.status_code != 401
+
+
+def test_get_run_requires_auth():
+    """GET /{thread_id}/runs/{run_id} requires auth."""
+    app = FastAPI()
+    app.include_router(router)
+
+    with TestClient(app, raise_server_exceptions=False) as client:
+        response = client.get("/api/threads/test-thread/runs/run-123")
+        assert response.status_code == 401
+
+
+def test_cancel_run_requires_auth():
+    """POST /{thread_id}/runs/{run_id}/cancel requires auth."""
+    app = FastAPI()
+    app.include_router(router)
+
+    with TestClient(app, raise_server_exceptions=False) as client:
+        response = client.post("/api/threads/test-thread/runs/run-123/cancel")
+        assert response.status_code == 401
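A side note on the test style above: the happy-path tests stack four or five nested with patch(...) blocks. If this pattern spreads across more router tests, a fixture built on contextlib.ExitStack can flatten it. The following is only a sketch under the assumption that the patch targets stay exactly as in the file above; the fixture name is arbitrary:

from contextlib import ExitStack
from unittest.mock import MagicMock, patch

import pytest


@pytest.fixture
def thread_runs_deps():
    """Patch the thread_runs gateway dependencies so tests exercise only auth."""
    mock_run_manager = MagicMock()
    mock_run_manager.list_by_thread = MagicMock(return_value=[])
    dependencies = {
        "app.gateway.routers.thread_runs.get_checkpointer": MagicMock(),
        "app.gateway.routers.thread_runs.get_run_manager": mock_run_manager,
        "app.gateway.routers.thread_runs.get_stream_bridge": MagicMock(),
    }
    with ExitStack() as stack:
        for target, replacement in dependencies.items():
            # patch(..., return_value=...) makes the patched getter return our mock.
            stack.enter_context(patch(target, return_value=replacement))
        yield mock_run_manager

Auth itself would still be patched per-test ("app.gateway.authz._authenticate"), since each test needs a different AuthContext.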
Some files were not shown because too many files have changed in this diff.