Merge branch 'main' into fix-2425

This commit is contained in:
Willem Jiang 2026-05-01 20:02:17 +08:00 committed by GitHub
commit 2cb66f4a7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
401 changed files with 38657 additions and 4467 deletions

View File

@ -34,5 +34,14 @@ INFOQUEST_API_KEY=your-infoquest-api-key
# GitHub API Token
# GITHUB_TOKEN=your-github-token
# Database (only needed when config.yaml has database.backend: postgres)
# DATABASE_URL=postgresql://deerflow:password@localhost:5432/deerflow
#
# WECOM_BOT_ID=your-wecom-bot-id
# WECOM_BOT_SECRET=your-wecom-bot-secret
# DINGTALK_CLIENT_ID=your-dingtalk-client-id
# DINGTALK_CLIENT_SECRET=your-dingtalk-client-secret
# Set to "false" to disable Swagger UI, ReDoc, and OpenAPI schema in production
# GATEWAY_ENABLE_DOCS=false

1
.gitignore vendored
View File

@ -40,6 +40,7 @@ coverage/
skills/custom/*
logs/
log/
debug.log
# Local git hooks (keep only on this machine, do not push)
.githooks/

33
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,33 @@
repos:
# Backend: ruff lint + format via uv (uses the same ruff version as backend deps)
- repo: local
hooks:
- id: ruff
name: ruff lint
entry: bash -c 'cd backend && uv run ruff check --fix "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
- id: ruff-format
name: ruff format
entry: bash -c 'cd backend && uv run ruff format "${@/#backend\//}"' --
language: system
types_or: [python]
files: ^backend/
# Frontend: eslint + prettier (must run from frontend/ for node_modules resolution)
- repo: local
hooks:
- id: frontend-eslint
name: eslint (frontend)
entry: bash -c 'cd frontend && npx eslint --fix "${@/#frontend\//}"' --
language: system
types_or: [javascript, tsx, ts]
files: ^frontend/
- id: frontend-prettier
name: prettier (frontend)
entry: bash -c 'cd frontend && npx prettier --write "${@/#frontend\//}"' --
language: system
files: ^frontend/
types_or: [javascript, tsx, ts, json, css]

View File

@ -166,7 +166,7 @@ Required tools:
1. **Configure the application** (same as Docker setup above)
2. **Install dependencies**:
2. **Install dependencies** (this also sets up pre-commit hooks):
```bash
make install
```

View File

@ -1,6 +1,6 @@
# DeerFlow - Unified Development Environment
.PHONY: help config config-upgrade check install setup doctor dev dev-pro dev-daemon dev-daemon-pro start start-pro start-daemon start-daemon-pro stop up up-pro down clean docker-init docker-start docker-start-pro docker-stop docker-logs docker-logs-frontend docker-logs-gateway
.PHONY: help config config-upgrade check install setup doctor dev dev-daemon start start-daemon stop up down clean docker-init docker-start docker-stop docker-logs docker-logs-frontend docker-logs-gateway
BASH ?= bash
BACKEND_UV_RUN = cd backend && uv run
@ -23,28 +23,22 @@ help:
@echo " make config - Generate local config files (aborts if config already exists)"
@echo " make config-upgrade - Merge new fields from config.example.yaml into config.yaml"
@echo " make check - Check if all required tools are installed"
@echo " make install - Install all dependencies (frontend + backend)"
@echo " make install - Install all dependencies (frontend + backend + pre-commit hooks)"
@echo " make setup-sandbox - Pre-pull sandbox container image (recommended)"
@echo " make dev - Start all services in development mode (with hot-reloading)"
@echo " make dev-pro - Start in dev + Gateway mode (experimental, no LangGraph server)"
@echo " make dev-daemon - Start dev services in background (daemon mode)"
@echo " make dev-daemon-pro - Start dev daemon + Gateway mode (experimental)"
@echo " make start - Start all services in production mode (optimized, no hot-reloading)"
@echo " make start-pro - Start in prod + Gateway mode (experimental)"
@echo " make start-daemon - Start prod services in background (daemon mode)"
@echo " make start-daemon-pro - Start prod daemon + Gateway mode (experimental)"
@echo " make stop - Stop all running services"
@echo " make clean - Clean up processes and temporary files"
@echo ""
@echo "Docker Production Commands:"
@echo " make up - Build and start production Docker services (localhost:2026)"
@echo " make up-pro - Build and start production Docker in Gateway mode (experimental)"
@echo " make down - Stop and remove production Docker containers"
@echo ""
@echo "Docker Development Commands:"
@echo " make docker-init - Pull the sandbox image"
@echo " make docker-start - Start Docker services (mode-aware from config.yaml, localhost:2026)"
@echo " make docker-start-pro - Start Docker in Gateway mode (experimental, no LangGraph container)"
@echo " make docker-stop - Stop Docker development services"
@echo " make docker-logs - View Docker development logs"
@echo " make docker-logs-frontend - View Docker frontend logs"
@ -73,6 +67,8 @@ install:
@cd backend && uv sync
@echo "Installing frontend dependencies..."
@cd frontend && pnpm install
@echo "Installing pre-commit hooks..."
@$(BACKEND_UV_RUN) --with pre-commit pre-commit install
@echo "✓ All dependencies installed"
@echo ""
@echo "=========================================="
@ -121,41 +117,21 @@ dev:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev
# Start all services in dev + Gateway mode (experimental: agent runtime embedded in Gateway)
dev-pro:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway
# Start all services in production mode (with optimizations)
start:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod
# Start all services in prod + Gateway mode (experimental)
start-pro:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway
# Start all services in daemon mode (background)
dev-daemon:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --daemon
# Start daemon + Gateway mode (experimental)
dev-daemon-pro:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --dev --gateway --daemon
# Start prod services in daemon mode (background)
start-daemon:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --daemon
# Start prod daemon + Gateway mode (experimental)
start-daemon-pro:
@$(PYTHON) ./scripts/check.py
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --prod --gateway --daemon
# Stop all services
stop:
@$(RUN_WITH_GIT_BASH) ./scripts/serve.sh --stop
@ -180,10 +156,6 @@ docker-init:
docker-start:
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start
# Start Docker in Gateway mode (experimental)
docker-start-pro:
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh start --gateway
# Stop Docker development environment
docker-stop:
@$(RUN_WITH_GIT_BASH) ./scripts/docker.sh stop
@ -206,10 +178,6 @@ docker-logs-gateway:
up:
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh
# Build and start production services in Gateway mode
up-pro:
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh --gateway
# Stop and remove production containers
down:
@$(RUN_WITH_GIT_BASH) ./scripts/deploy.sh down

View File

@ -243,9 +243,6 @@ make up # Build images and start all production services
make down # Stop and remove containers
```
> [!NOTE]
> The LangGraph agent server currently runs via `langgraph dev` (the open-source CLI server).
Access: http://localhost:2026
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
@ -264,7 +261,7 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
2. **Install dependencies**:
```bash
make install # Install backend + frontend dependencies
make install # Install backend + frontend dependencies + pre-commit hooks
```
3. **(Optional) Pre-pull sandbox image**:
@ -289,53 +286,31 @@ On Windows, run the local development flow from Git Bash. Native `cmd.exe` and P
#### Startup Modes
DeerFlow supports multiple startup modes across two dimensions:
- **Dev / Prod** — dev enables hot-reload; prod uses pre-built frontend
- **Standard / Gateway** — standard uses a separate LangGraph server (4 processes); Gateway mode (experimental) embeds the agent runtime in the Gateway API (3 processes)
DeerFlow runs the agent runtime inside the Gateway API. Development mode enables hot-reload; production mode uses a pre-built frontend.
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|---|---|---|---|---|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
| Action | Local | Docker Dev | Docker Prod |
|---|---|---|---|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
> **Gateway mode** eliminates the LangGraph server process — the Gateway API handles agent execution directly via async tasks, managing its own concurrency.
#### Why Gateway Mode?
In standard mode, DeerFlow runs a dedicated [LangGraph Platform](https://langchain-ai.github.io/langgraph/) server alongside the Gateway API. This architecture works well but has trade-offs:
| | Standard Mode | Gateway Mode |
|---|---|---|
| **Architecture** | Gateway (REST API) + LangGraph (agent runtime) | Gateway embeds agent runtime |
| **Concurrency** | `--n-jobs-per-worker` per worker (requires license) | `--workers` × async tasks (no per-worker cap) |
| **Containers / Processes** | 4 (frontend, gateway, langgraph, nginx) | 3 (frontend, gateway, nginx) |
| **Resource usage** | Higher (two Python runtimes) | Lower (single Python runtime) |
| **LangGraph Platform license** | Required for production images | Not required |
| **Cold start** | Slower (two services to initialize) | Faster |
Both modes are functionally equivalent — the same agents, tools, and skills work in either mode.
Gateway owns `/api/langgraph/*` and translates those public LangGraph-compatible paths to its native `/api/*` routers behind nginx.
#### Docker Production Deployment
`deploy.sh` supports building and starting separately. Images are mode-agnostic — runtime mode is selected at start time:
`deploy.sh` supports building and starting separately:
```bash
# One-step (build + start)
deploy.sh # standard mode (default)
deploy.sh --gateway # gateway mode
deploy.sh
# Two-step (build once, start with any mode)
# Two-step (build once, start later)
deploy.sh build # build all images
deploy.sh start # start in standard mode
deploy.sh start --gateway # start in gateway mode
deploy.sh start # start pre-built images
# Stop
deploy.sh down
@ -370,13 +345,14 @@ DeerFlow supports receiving tasks from messaging apps. Channels auto-start when
| Feishu / Lark | WebSocket | Moderate |
| WeChat | Tencent iLink (long-polling) | Moderate |
| WeCom | WebSocket | Moderate |
| DingTalk | Stream Push (WebSocket) | Moderate |
**Configuration in `config.yaml`:**
```yaml
channels:
# LangGraph Server URL (default: http://localhost:2024)
langgraph_url: http://localhost:2024
# LangGraph-compatible Gateway API base URL (default: http://localhost:8001/api)
langgraph_url: http://localhost:8001/api
# Gateway API URL (default: http://localhost:8001)
gateway_url: http://localhost:8001
@ -439,11 +415,19 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # Client ID of your DingTalk application
client_secret: $DINGTALK_CLIENT_SECRET # Client Secret of your DingTalk application
allowed_users: [] # empty = allow all
card_template_id: "" # Optional: AI Card template ID for streaming typewriter effect
```
Notes:
- `assistant_id: lead_agent` calls the default LangGraph assistant directly.
- If `assistant_id` is set to a custom agent name, DeerFlow still routes through `lead_agent` and injects that value as `agent_name`, so the custom agent's SOUL/config takes effect for IM channels.
- IM channel workers call Gateway's LangGraph-compatible API internally and automatically attach process-local internal auth plus the CSRF cookie/header pair required for thread and run creation.
Set the corresponding API keys in your `.env` file:
@ -466,6 +450,10 @@ WECHAT_ILINK_BOT_ID=your_ilink_bot_id
# WeCom
WECOM_BOT_ID=your_bot_id
WECOM_BOT_SECRET=your_bot_secret
# DingTalk
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Telegram Setup**
@ -504,7 +492,15 @@ WECOM_BOT_SECRET=your_bot_secret
4. Make sure backend dependencies include `wecom-aibot-python-sdk`. The channel uses a WebSocket long connection and does not require a public callback URL.
5. The current integration supports inbound text, image, and file messages. Final images/files generated by the agent are also sent back to the WeCom conversation.
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://langgraph:2024` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
**DingTalk Setup**
1. Create a DingTalk application in the [DingTalk Developer Console](https://open.dingtalk.com/) and enable **Robot** capability.
2. Set the message receiving mode to **Stream Mode** in the robot configuration page.
3. Copy the `Client ID` and `Client Secret`, set `DINGTALK_CLIENT_ID` and `DINGTALK_CLIENT_SECRET` in `.env`, and enable the channel in `config.yaml`.
4. *(Optional)* To enable streaming AI Card replies (typewriter effect), create an **AI Card** template on the [DingTalk Card Platform](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), then set `card_template_id` in `config.yaml` to the template ID. You also need to apply for the `Card.Streaming.Write` and `Card.Instance.Write` permissions.
When DeerFlow runs in Docker Compose, IM channels execute inside the `gateway` container. In that case, do not point `channels.langgraph_url` or `channels.gateway_url` at `localhost`; use container service names such as `http://gateway:8001/api` and `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` and `DEER_FLOW_CHANNELS_GATEWAY_URL`.
**Commands**

View File

@ -290,6 +290,7 @@ DeerFlow peut recevoir des tâches depuis des applications de messagerie. Les ca
| Telegram | Bot API (long-polling) | Facile |
| Slack | Socket Mode | Modérée |
| Feishu / Lark | WebSocket | Modérée |
| DingTalk | Stream Push (WebSocket) | Modérée |
**Configuration dans `config.yaml` :**
@ -341,6 +342,13 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # ClientId depuis DingTalk Open Platform
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret depuis DingTalk Open Platform
allowed_users: [] # vide = tout le monde autorisé
card_template_id: "" # Optionnel : ID de modèle AI Card pour l'effet machine à écrire en streaming
```
Définissez les clés API correspondantes dans votre fichier `.env` :
@ -356,6 +364,10 @@ SLACK_APP_TOKEN=xapp-...
# Feishu / Lark
FEISHU_APP_ID=cli_xxxx
FEISHU_APP_SECRET=your_app_secret
# DingTalk
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Configuration Telegram**
@ -378,6 +390,13 @@ FEISHU_APP_SECRET=your_app_secret
3. Dans **Events**, abonnez-vous à `im.message.receive_v1` et sélectionnez le mode **Long Connection**.
4. Copiez l'App ID et l'App Secret. Définissez `FEISHU_APP_ID` et `FEISHU_APP_SECRET` dans `.env` et activez le canal dans `config.yaml`.
**Configuration DingTalk**
1. Créez une application sur [DingTalk Open Platform](https://open.dingtalk.com/) et activez la capacité **Robot**.
2. Dans la page de configuration du robot, définissez le mode de réception des messages sur **Stream**.
3. Copiez le `Client ID` et le `Client Secret`. Définissez `DINGTALK_CLIENT_ID` et `DINGTALK_CLIENT_SECRET` dans `.env` et activez le canal dans `config.yaml`.
4. *(Optionnel)* Pour activer les réponses en streaming AI Card (effet machine à écrire), créez un modèle **AI Card** sur la [plateforme de cartes DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), puis définissez `card_template_id` dans `config.yaml` avec l'ID du modèle. Vous devez également demander les permissions `Card.Streaming.Write` et `Card.Instance.Write`.
**Commandes**
Une fois un canal connecté, vous pouvez interagir avec DeerFlow directement depuis le chat :

View File

@ -243,6 +243,7 @@ DeerFlowはメッセージングアプリからのタスク受信をサポート
| Telegram | Bot APIロングポーリング | 簡単 |
| Slack | Socket Mode | 中程度 |
| Feishu / Lark | WebSocket | 中程度 |
| DingTalk | Stream PushWebSocket | 中程度 |
**`config.yaml`での設定:**
@ -294,6 +295,13 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # DingTalk Open PlatformのClientId
client_secret: $DINGTALK_CLIENT_SECRET # DingTalk Open PlatformのClientSecret
allowed_users: [] # 空 = 全員許可
card_template_id: "" # オプションストリーミングタイプライター効果用のAIカードテンプレートID
```
対応するAPIキーを`.env`ファイルに設定します:
@ -309,6 +317,10 @@ SLACK_APP_TOKEN=xapp-...
# Feishu / Lark
FEISHU_APP_ID=cli_xxxx
FEISHU_APP_SECRET=your_app_secret
# DingTalk
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Telegramのセットアップ**
@ -331,6 +343,13 @@ FEISHU_APP_SECRET=your_app_secret
3. **イベント**で`im.message.receive_v1`を購読し、**ロングコネクション**モードを選択。
4. App IDとApp Secretをコピー。`.env``FEISHU_APP_ID``FEISHU_APP_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
**DingTalkのセットアップ**
1. [DingTalk Open Platform](https://open.dingtalk.com/)でアプリを作成し、**ロボット**機能を有効化します。
2. ロボット設定ページでメッセージ受信モードを**Streamモード**に設定します。
3. `Client ID``Client Secret`をコピー。`.env``DINGTALK_CLIENT_ID``DINGTALK_CLIENT_SECRET`を設定し、`config.yaml`でチャネルを有効にします。
4. *(オプション)* ストリーミングAIカード返信タイプライター効果を有効にするには、[DingTalkカードプラットフォーム](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)で**AIカード**テンプレートを作成し、`config.yaml``card_template_id`にテンプレートIDを設定します。`Card.Streaming.Write` および `Card.Instance.Write` 権限の申請も必要です。
**コマンド**
チャネル接続後、チャットから直接DeerFlowと対話できます

View File

@ -256,6 +256,7 @@ DeerFlow принимает задачи прямо из мессенджеро
| Telegram | Bot API (long-polling) | Просто |
| Slack | Socket Mode | Средне |
| Feishu / Lark | WebSocket | Средне |
| DingTalk | Stream Push (WebSocket) | Средне |
**Конфигурация в `config.yaml`:**
@ -278,6 +279,13 @@ channels:
enabled: true
bot_token: $TELEGRAM_BOT_TOKEN
allowed_users: []
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # ClientId с DingTalk Open Platform
client_secret: $DINGTALK_CLIENT_SECRET # ClientSecret с DingTalk Open Platform
allowed_users: [] # пусто = разрешить всем
card_template_id: "" # Опционально: ID шаблона AI Card для потокового эффекта печатной машинки
```
**Настройка Telegram**
@ -285,6 +293,13 @@ channels:
1. Напишите [@BotFather](https://t.me/BotFather), отправьте `/newbot` и скопируйте HTTP API-токен.
2. Укажите `TELEGRAM_BOT_TOKEN` в `.env` и включите канал в `config.yaml`.
**Настройка DingTalk**
1. Создайте приложение на [DingTalk Open Platform](https://open.dingtalk.com/) и включите возможность **Робот**.
2. На странице настроек робота установите режим приёма сообщений на **Stream**.
3. Скопируйте `Client ID` и `Client Secret`. Укажите `DINGTALK_CLIENT_ID` и `DINGTALK_CLIENT_SECRET` в `.env` и включите канал в `config.yaml`.
4. *(Опционально)* Для включения потоковых ответов AI Card (эффект печатной машинки) создайте шаблон **AI Card** на [платформе карточек DingTalk](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card), затем укажите `card_template_id` в `config.yaml` с ID шаблона. Также необходимо запросить разрешения `Card.Streaming.Write` и `Card.Instance.Write`.
**Доступные команды**
| Команда | Описание |

View File

@ -248,6 +248,7 @@ DeerFlow 支持从即时通讯应用接收任务。只要配置完成,对应
| Slack | Socket Mode | 中等 |
| Feishu / Lark | WebSocket | 中等 |
| 企业微信智能机器人 | WebSocket | 中等 |
| 钉钉 | Stream PushWebSocket | 中等 |
**`config.yaml` 中的配置示例:**
@ -304,6 +305,13 @@ channels:
context:
thinking_enabled: true
subagent_enabled: true
dingtalk:
enabled: true
client_id: $DINGTALK_CLIENT_ID # 钉钉开放平台 ClientId
client_secret: $DINGTALK_CLIENT_SECRET # 钉钉开放平台 ClientSecret
allowed_users: [] # 留空表示允许所有人
card_template_id: "" # 可选AI 卡片模板 ID用于流式打字机效果
```
说明:
@ -327,6 +335,10 @@ FEISHU_APP_SECRET=your_app_secret
# 企业微信智能机器人
WECOM_BOT_ID=your_bot_id
WECOM_BOT_SECRET=your_bot_secret
# 钉钉
DINGTALK_CLIENT_ID=your_client_id
DINGTALK_CLIENT_SECRET=your_client_secret
```
**Telegram 配置**
@ -357,6 +369,13 @@ WECOM_BOT_SECRET=your_bot_secret
4. 安装后端依赖时确保包含 `wecom-aibot-python-sdk`,渠道会通过 WebSocket 长连接接收消息,无需公网回调地址。
5. 当前支持文本、图片和文件入站消息agent 生成的最终图片/文件也会回传到企业微信会话中。
**钉钉配置**
1. 在 [钉钉开放平台](https://open.dingtalk.com/) 创建应用,并启用 **机器人** 能力。
2. 在机器人配置页面设置消息接收模式为 **Stream模式**
3. 复制 `Client ID``Client Secret`,在 `.env` 中设置 `DINGTALK_CLIENT_ID``DINGTALK_CLIENT_SECRET`,并在 `config.yaml` 中启用该渠道。
4. *(可选)* 如需开启流式 AI 卡片回复(打字机效果),请在[钉钉卡片平台](https://open.dingtalk.com/document/dingstart/typewriter-effect-streaming-ai-card)创建 **AI 卡片**模板,然后在 `config.yaml` 中将 `card_template_id` 设为该模板 ID。同时需要申请 `Card.Streaming.Write``Card.Instance.Write` 权限。
**命令**
渠道连接完成后,你可以直接在聊天窗口里和 DeerFlow 交互:

View File

@ -7,15 +7,13 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co
DeerFlow is a LangGraph-based AI super agent system with a full-stack architecture. The backend provides a "super agent" with sandbox execution, persistent memory, subagent delegation, and extensible tool integration - all operating in per-thread isolated environments.
**Architecture**:
- **LangGraph Server** (port 2024): Agent runtime and workflow execution
- **Gateway API** (port 8001): REST API for models, MCP, skills, memory, artifacts, uploads, and local thread cleanup
- **Gateway API** (port 8001): REST API plus embedded LangGraph-compatible agent runtime
- **Frontend** (port 3000): Next.js web interface
- **Nginx** (port 2026): Unified reverse proxy entry point
- **Provisioner** (port 8002, optional in Docker dev): Started only when sandbox is configured for provisioner/Kubernetes mode
**Runtime Modes**:
- **Standard mode** (`make dev`): LangGraph Server handles agent execution as a separate process. 4 processes total.
- **Gateway mode** (`make dev-pro`, experimental): Agent runtime embedded in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Service manages its own concurrency via async tasks. 3 processes total, no LangGraph Server.
**Runtime**:
- `make dev`, Docker dev, and production all run the agent runtime in Gateway via `RunManager` + `run_agent()` + `StreamBridge` (`packages/harness/deerflow/runtime/`). Nginx exposes that runtime at `/api/langgraph/*` and rewrites it to Gateway's native `/api/*` routers.
**Project Structure**:
```
@ -25,7 +23,7 @@ deer-flow/
├── extensions_config.json # MCP servers and skills configuration
├── backend/ # Backend application (this directory)
│ ├── Makefile # Backend-only commands (dev, gateway, lint)
│ ├── langgraph.json # LangGraph server configuration
│ ├── langgraph.json # LangGraph Studio graph configuration
│ ├── packages/
│ │ └── harness/ # deerflow-harness package (import: deerflow.*)
│ │ ├── pyproject.toml
@ -83,16 +81,15 @@ When making code changes, you MUST update the relevant documentation:
```bash
make check # Check system requirements
make install # Install all dependencies (frontend + backend)
make dev # Start all services (LangGraph + Gateway + Frontend + Nginx), with config.yaml preflight
make dev-pro # Gateway mode (experimental): skip LangGraph, agent runtime embedded in Gateway
make start-pro # Production + Gateway mode (experimental)
make dev # Start all services (Gateway + Frontend + Nginx), with config.yaml preflight
make start # Start production services locally
make stop # Stop all services
```
**Backend directory** (for backend development only):
```bash
make install # Install backend dependencies
make dev # Run LangGraph server only (port 2024)
make dev # Run Gateway API with reload (port 8001)
make gateway # Run Gateway API only (port 8001)
make test # Run all backend tests
make lint # Lint with ruff
@ -115,7 +112,7 @@ CI runs these regression tests for every pull request via [.github/workflows/bac
The backend is split into two layers with a strict dependency direction:
- **Harness** (`packages/harness/deerflow/`): Publishable agent framework package (`deerflow-harness`). Import prefix: `deerflow.*`. Contains agent orchestration, tools, sandbox, models, MCP, skills, config — everything needed to build and run agents.
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram).
- **App** (`app/`): Unpublished application code. Import prefix: `app.*`. Contains the FastAPI Gateway API and IM channel integrations (Feishu, Slack, Telegram, DingTalk).
**Dependency rule**: App imports deerflow, but deerflow never imports app. This boundary is enforced by `tests/test_harness_boundary.py` which runs in CI.
@ -158,7 +155,7 @@ from deerflow.config import get_app_config
Lead-agent middlewares are assembled in strict append order across `packages/harness/deerflow/agents/middlewares/tool_error_handling_middleware.py` (`build_lead_runtime_middlewares`) and `packages/harness/deerflow/agents/lead_agent/agent.py` (`_build_middlewares`):
1. **ThreadDataMiddleware** - Creates per-thread directories (`backend/.deer-flow/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local `.deer-flow/threads/{thread_id}` directory
1. **ThreadDataMiddleware** - Creates per-thread directories under the user's isolation scope (`backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/{workspace,uploads,outputs}`); resolves `user_id` via `get_effective_user_id()` (falls back to `"default"` in no-auth mode); Web UI thread deletion now follows LangGraph thread removal with Gateway cleanup of the local thread directory
2. **UploadsMiddleware** - Tracks and injects newly uploaded files into conversation
3. **SandboxMiddleware** - Acquires sandbox, stores `sandbox_id` in state
4. **DanglingToolCallMiddleware** - Injects placeholder ToolMessages for AIMessage tool_calls that lack responses (e.g., due to user interruption), including raw provider tool-call payloads preserved only in `additional_kwargs["tool_calls"]`
@ -208,7 +205,7 @@ Configuration priority:
### Gateway API (`app/gateway/`)
FastAPI application on port 8001 with health check at `GET /health`.
FastAPI application on port 8001 with health check at `GET /health`. Set `GATEWAY_ENABLE_DOCS=false` to disable `/docs`, `/redoc`, and `/openapi.json` in production (default: enabled).
**Routers**:
@ -222,6 +219,9 @@ FastAPI application on port 8001 with health check at `GET /health`.
| **Threads** (`/api/threads/{id}`) | `DELETE /` - remove DeerFlow-managed local thread data after LangGraph thread deletion; unexpected failures are logged server-side and return a generic 500 detail |
| **Artifacts** (`/api/threads/{id}/artifacts`) | `GET /{path}` - serve artifacts; active content types (`text/html`, `application/xhtml+xml`, `image/svg+xml`) are always forced as download attachments to reduce XSS risk; `?download=true` still forces download for other file types |
| **Suggestions** (`/api/threads/{id}/suggestions`) | `POST /` - generate follow-up questions; rich list/block model content is normalized before JSON parsing |
| **Thread Runs** (`/api/threads/{id}/runs`) | `POST /` - create background run; `POST /stream` - create + SSE stream; `POST /wait` - create + block; `GET /` - list runs; `GET /{rid}` - run details; `POST /{rid}/cancel` - cancel; `GET /{rid}/join` - join SSE; `GET /{rid}/messages` - paginated messages `{data, has_more}`; `GET /{rid}/events` - full event stream; `GET /../messages` - thread messages with feedback; `GET /../token-usage` - aggregate tokens |
| **Feedback** (`/api/threads/{id}/runs/{rid}/feedback`) | `PUT /` - upsert feedback; `DELETE /` - delete user feedback; `POST /` - create feedback; `GET /` - list feedback; `GET /stats` - aggregate stats; `DELETE /{fid}` - delete specific |
| **Runs** (`/api/runs`) | `POST /stream` - stateless run + SSE; `POST /wait` - stateless run + block; `GET /{rid}/messages` - paginated messages by run_id `{data, has_more}` (cursor: `after_seq`/`before_seq`); `GET /{rid}/feedback` - list feedback by run_id |
Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` → Gateway.
@ -235,7 +235,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
**Virtual Path System**:
- Agent sees: `/mnt/user-data/{workspace,uploads,outputs}`, `/mnt/skills`
- Physical: `backend/.deer-flow/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
- Physical: `backend/.deer-flow/users/{user_id}/threads/{thread_id}/user-data/...`, `deer-flow/skills/`
- Translation: `replace_virtual_path()` / `replace_virtual_paths_in_command()`
- Detection: `is_local_sandbox()` checks `sandbox_id == "local"`
@ -275,7 +275,7 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
- `invoke_acp_agent` - Invokes external ACP-compatible agents from `config.yaml`
- ACP launchers must be real ACP adapters. The standard `codex` CLI is not ACP-compatible by itself; configure a wrapper such as `npx -y @zed-industries/codex-acp` or an installed `codex-acp` binary
- Missing ACP executables now return an actionable error message instead of a raw `[Errno 2]`
- Each ACP agent uses a per-thread workspace at `{base_dir}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
- Each ACP agent uses a per-thread workspace at `{base_dir}/users/{user_id}/threads/{thread_id}/acp-workspace/`. The workspace is accessible to the lead agent via the virtual path `/mnt/acp-workspace/` (read-only). In docker sandbox mode, the directory is volume-mounted into the container at `/mnt/acp-workspace` (read-only); in local sandbox mode, path translation is handled by `tools.py`
- `image_search/` - Image search via DuckDuckGo
### MCP System (`packages/harness/deerflow/mcp/`)
@ -312,9 +312,10 @@ Proxied through nginx: `/api/langgraph/*` → LangGraph, all other `/api/*` →
### IM Channels System (`app/channels/`)
Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow agent via the LangGraph Server.
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via the LangGraph Server.
**Architecture**: Channels communicate with the LangGraph Server through `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side.
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
**Components**:
- `message_bus.py` - Async pub/sub hub (`InboundMessage` → queue → dispatcher; `OutboundMessage` → callbacks → channels)
@ -322,40 +323,51 @@ Bridges external messaging platforms (Feishu, Slack, Telegram) to the DeerFlow a
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
- `slack.py` / `feishu.py` / `telegram.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place)
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
**Message Flow**:
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
2. `ChannelManager._dispatch_loop()` consumes from queue
3. For chat: look up/create thread on LangGraph Server
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
7. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
8. Outbound → channel callbacks → platform reply
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
9. Outbound → channel callbacks → platform reply
**Configuration** (`config.yaml` -> `channels`):
- `langgraph_url` - LangGraph Server URL (default: `http://localhost:2024`)
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
- `gateway_url` - Gateway API URL for auxiliary commands (default: `http://localhost:8001`)
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://langgraph:2024` / `http://gateway:8001`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token)
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
### Memory System (`packages/harness/deerflow/agents/memory/`)
**Components**:
- `updater.py` - LLM-based memory updates with fact extraction, whitespace-normalized fact deduplication (trims leading/trailing whitespace before comparing), and atomic file I/O
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time)
- `queue.py` - Debounced update queue (per-thread deduplication, configurable wait time); captures `user_id` at enqueue time so it survives the `threading.Timer` boundary
- `prompt.py` - Prompt templates for memory updates
- `storage.py` - File-based storage with per-user isolation; cache keyed by `(user_id, agent_name)` tuple
**Data Structure** (stored in `backend/.deer-flow/memory.json`):
**Per-User Isolation**:
- Memory is stored per-user at `{base_dir}/users/{user_id}/memory.json`
- Per-agent per-user memory at `{base_dir}/users/{user_id}/agents/{agent_name}/memory.json`
- `user_id` is resolved via `get_effective_user_id()` from `deerflow.runtime.user_context`
- In no-auth mode, `user_id` defaults to `"default"` (constant `DEFAULT_USER_ID`)
- Absolute `storage_path` in config opts out of per-user isolation
- **Migration**: Run `PYTHONPATH=. python scripts/migrate_user_isolation.py` to move legacy `memory.json` and `threads/` into per-user layout; supports `--dry-run`
**Data Structure** (stored in `{base_dir}/users/{user_id}/memory.json`):
- **User Context**: `workContext`, `personalContext`, `topOfMind` (1-3 sentence summaries)
- **History**: `recentMonths`, `earlierContext`, `longTermBackground`
- **Facts**: Discrete facts with `id`, `content`, `category` (preference/knowledge/context/behavior/goal), `confidence` (0-1), `createdAt`, `source`
**Workflow**:
1. `MemoryMiddleware` filters messages (user inputs + final AI responses) and queues conversation
1. `MemoryMiddleware` filters messages (user inputs + final AI responses), captures `user_id` via `get_effective_user_id()`, and queues conversation with the captured `user_id`
2. Queue debounces (30s default), batches updates, deduplicates per-thread
3. Background thread invokes LLM to extract context updates and facts
3. Background thread invokes LLM to extract context updates and facts, using the stored `user_id` (not the contextvar, which is unavailable on timer threads)
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
@ -363,7 +375,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
**Configuration** (`config.yaml``memory`):
- `enabled` / `injection_enabled` - Master switches
- `storage_path` - Path to memory.json
- `storage_path` - Path to memory.json (absolute path opts out of per-user isolation)
- `debounce_seconds` - Wait time before processing (default: 30)
- `model_name` - LLM for updates (null = default model)
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
@ -398,9 +410,9 @@ Both can be modified at runtime via Gateway API endpoints or `DeerFlowClient` me
`DeerFlowClient` provides direct in-process access to all DeerFlow capabilities without HTTP services. All return types align with the Gateway API response schemas, so consumer code works identically in HTTP and embedded modes.
**Architecture**: Imports the same `deerflow` modules that LangGraph Server and Gateway API use. Shares the same config files and data directories. No FastAPI dependency.
**Architecture**: Imports the same `deerflow` modules that Gateway API uses. Shares the same config files and data directories. No FastAPI dependency.
**Agent Conversation** (replaces LangGraph Server):
**Agent Conversation**:
- `chat(message, thread_id)` — synchronous, accumulates streaming deltas per message-id and returns the final AI text
- `stream(message, thread_id)` — subscribes to LangGraph `stream_mode=["values", "messages", "custom"]` and yields `StreamEvent`:
- `"values"` — full state snapshot (title, messages, artifacts); AI text already delivered via `messages` mode is **not** re-synthesized here to avoid duplicate deliveries
@ -463,20 +475,15 @@ This starts all services and makes the application available at `http://localhos
| | **Local Foreground** | **Local Daemon** | **Docker Dev** | **Docker Prod** |
|---|---|---|---|---|
| **Dev** | `./scripts/serve.sh --dev`<br/>`make dev` | `./scripts/serve.sh --dev --daemon`<br/>`make dev-daemon` | `./scripts/docker.sh start`<br/>`make docker-start` | — |
| **Dev + Gateway** | `./scripts/serve.sh --dev --gateway`<br/>`make dev-pro` | `./scripts/serve.sh --dev --gateway --daemon`<br/>`make dev-daemon-pro` | `./scripts/docker.sh start --gateway`<br/>`make docker-start-pro` | — |
| **Prod** | `./scripts/serve.sh --prod`<br/>`make start` | `./scripts/serve.sh --prod --daemon`<br/>`make start-daemon` | — | `./scripts/deploy.sh`<br/>`make up` |
| **Prod + Gateway** | `./scripts/serve.sh --prod --gateway`<br/>`make start-pro` | `./scripts/serve.sh --prod --gateway --daemon`<br/>`make start-daemon-pro` | — | `./scripts/deploy.sh --gateway`<br/>`make up-pro` |
| Action | Local | Docker Dev | Docker Prod |
|---|---|---|---|
| **Stop** | `./scripts/serve.sh --stop`<br/>`make stop` | `./scripts/docker.sh stop`<br/>`make docker-stop` | `./scripts/deploy.sh down`<br/>`make down` |
| **Restart** | `./scripts/serve.sh --restart [flags]` | `./scripts/docker.sh restart` | — |
Gateway mode embeds the agent runtime in Gateway, no LangGraph server.
**Nginx routing**:
- Standard mode: `/api/langgraph/*` → LangGraph Server (2024)
- Gateway mode: `/api/langgraph/*` → Gateway embedded runtime (8001) (via envsubst)
- `/api/langgraph/*` → Gateway embedded runtime (8001), rewritten to `/api/*`
- `/api/*` (other) → Gateway API (8001)
- `/` (non-API) → Frontend (3000)
@ -485,15 +492,11 @@ Gateway mode embeds the agent runtime in Gateway, no LangGraph server.
From the **backend** directory:
```bash
# Terminal 1: LangGraph server
make dev
# Terminal 2: Gateway API
# Gateway API
make gateway
```
Direct access (without nginx):
- LangGraph: `http://localhost:2024`
- Gateway: `http://localhost:8001`
### Frontend Configuration

View File

@ -13,6 +13,9 @@ FROM python:3.12-slim-bookworm AS builder
ARG NODE_MAJOR=22
ARG APT_MIRROR
ARG UV_INDEX_URL
# Optional extras to install (e.g. "postgres" for PostgreSQL support)
# Usage: docker build --build-arg UV_EXTRAS=postgres ...
ARG UV_EXTRAS
# Optionally override apt mirror for restricted networks (e.g. APT_MIRROR=mirrors.aliyun.com)
RUN if [ -n "${APT_MIRROR}" ]; then \
@ -43,8 +46,9 @@ WORKDIR /app
COPY backend ./backend
# Install dependencies with cache mount
# When UV_EXTRAS is set (e.g. "postgres"), installs optional dependencies.
RUN --mount=type=cache,target=/root/.cache/uv \
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync"
sh -c "cd backend && UV_INDEX_URL=${UV_INDEX_URL:-https://pypi.org/simple} uv sync ${UV_EXTRAS:+--extra $UV_EXTRAS}"
# ── Stage 2: Dev ──────────────────────────────────────────────────────────────
# Retains compiler toolchain from builder so startup-time `uv sync` can build

View File

@ -2,7 +2,7 @@ install:
uv sync
dev:
uv run langgraph dev --no-browser --no-reload --n-jobs-per-worker 10
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --reload
gateway:
PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001

View File

@ -2,7 +2,7 @@
Provides a pluggable channel system that connects external messaging platforms
(Feishu/Lark, Slack, Telegram) to the DeerFlow agent via the ChannelManager,
which uses ``langgraph-sdk`` to communicate with the underlying LangGraph Server.
which uses ``langgraph-sdk`` to communicate with Gateway's LangGraph-compatible API.
"""
from app.channels.base import Channel

View File

@ -31,6 +31,10 @@ class Channel(ABC):
def is_running(self) -> bool:
return self._running
@property
def supports_streaming(self) -> bool:
return False
# -- lifecycle ---------------------------------------------------------
@abstractmethod

View File

@ -0,0 +1,740 @@
"""DingTalk channel implementation."""
from __future__ import annotations
import asyncio
import json
import logging
import re
import threading
import time
from pathlib import Path
from typing import Any
import httpx
from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__)
DINGTALK_API_BASE = "https://api.dingtalk.com"
_TOKEN_REFRESH_MARGIN_SECONDS = 300
_CONVERSATION_TYPE_P2P = "1"
_CONVERSATION_TYPE_GROUP = "2"
_MAX_UPLOAD_SIZE_BYTES = 20 * 1024 * 1024
def _normalize_conversation_type(raw: Any) -> str:
"""Normalize ``conversationType`` to ``"1"`` (P2P) or ``"2"`` (group).
Stream payloads may send int or string values.
"""
if raw is None:
return _CONVERSATION_TYPE_P2P
s = str(raw).strip()
if s == _CONVERSATION_TYPE_GROUP:
return _CONVERSATION_TYPE_GROUP
return _CONVERSATION_TYPE_P2P
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
if allowed_users is None:
return set()
if isinstance(allowed_users, str):
values = [allowed_users]
elif isinstance(allowed_users, (list, tuple, set)):
values = allowed_users
else:
logger.warning(
"DingTalk allowed_users should be a list of user IDs; treating %s as one string value",
type(allowed_users).__name__,
)
values = [allowed_users]
return {str(uid) for uid in values if str(uid)}
def _is_dingtalk_command(text: str) -> bool:
if not text.startswith("/"):
return False
return text.split(maxsplit=1)[0].lower() in KNOWN_CHANNEL_COMMANDS
def _extract_text_from_rich_text(rich_text_list: list) -> str:
parts: list[str] = []
for item in rich_text_list:
if isinstance(item, dict) and "text" in item:
parts.append(item["text"])
return " ".join(parts)
_FENCED_CODE_BLOCK_RE = re.compile(r"```(\w*)\n(.*?)```", re.DOTALL)
_INLINE_CODE_RE = re.compile(r"`([^`\n]+)`")
_HORIZONTAL_RULE_RE = re.compile(r"^-{3,}$", re.MULTILINE)
_TABLE_SEPARATOR_RE = re.compile(r"^\|[-:| ]+\|$", re.MULTILINE)
def _convert_markdown_table(text: str) -> str:
# DingTalk sampleMarkdown does not render pipe-delimited tables.
lines = text.split("\n")
result: list[str] = []
i = 0
while i < len(lines):
line = lines[i]
# Detect table: header row followed by separator row
if i + 1 < len(lines) and line.strip().startswith("|") and _TABLE_SEPARATOR_RE.match(lines[i + 1].strip()):
headers = [h.strip() for h in line.strip().strip("|").split("|")]
i += 2 # skip header + separator
while i < len(lines) and lines[i].strip().startswith("|"):
cells = [c.strip() for c in lines[i].strip().strip("|").split("|")]
for h, c in zip(headers, cells):
result.append(f"> **{h}**: {c}")
result.append("")
i += 1
else:
result.append(line)
i += 1
return "\n".join(result)
def _adapt_markdown_for_dingtalk(text: str) -> str:
"""Adapt markdown for DingTalk's limited sampleMarkdown renderer."""
def _code_block_to_quote(match: re.Match) -> str:
lang = match.group(1)
code = match.group(2).rstrip("\n")
prefix = f"> **{lang}**\n" if lang else ""
quoted_lines = "\n".join(f"> {line}" for line in code.split("\n"))
return f"{prefix}{quoted_lines}\n"
text = _FENCED_CODE_BLOCK_RE.sub(_code_block_to_quote, text)
text = _INLINE_CODE_RE.sub(r"**\1**", text)
text = _convert_markdown_table(text)
text = _HORIZONTAL_RULE_RE.sub("───────────", text)
return text
class DingTalkChannel(Channel):
"""DingTalk IM channel using Stream Push (WebSocket, no public IP needed)."""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
super().__init__(name="dingtalk", bus=bus, config=config)
self._thread: threading.Thread | None = None
self._main_loop: asyncio.AbstractEventLoop | None = None
self._client_id: str = ""
self._client_secret: str = ""
self._allowed_users: set[str] = _normalize_allowed_users(config.get("allowed_users"))
self._cached_token: str = ""
self._token_expires_at: float = 0.0
self._token_lock = asyncio.Lock()
self._card_template_id: str = config.get("card_template_id", "")
self._card_track_ids: dict[str, str] = {}
self._dingtalk_client: Any = None
self._stream_client: Any = None
self._incoming_messages: dict[str, Any] = {}
self._incoming_messages_lock = threading.Lock()
self._card_repliers: dict[str, Any] = {}
@property
def supports_streaming(self) -> bool:
return bool(self._card_template_id)
async def start(self) -> None:
if self._running:
return
try:
import dingtalk_stream # noqa: F401
except ImportError:
logger.error("dingtalk-stream is not installed. Install it with: uv add dingtalk-stream")
return
client_id = self.config.get("client_id", "")
client_secret = self.config.get("client_secret", "")
if not client_id or not client_secret:
logger.error("DingTalk channel requires client_id and client_secret")
return
self._client_id = client_id
self._client_secret = client_secret
self._main_loop = asyncio.get_running_loop()
if self._card_template_id:
logger.info("[DingTalk] AI Card mode enabled (template=%s)", self._card_template_id)
self._running = True
self.bus.subscribe_outbound(self._on_outbound)
self._thread = threading.Thread(
target=self._run_stream,
args=(client_id, client_secret),
daemon=True,
)
self._thread.start()
logger.info("DingTalk channel started")
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
stream_client = self._stream_client
if stream_client is not None:
try:
if hasattr(stream_client, "disconnect"):
stream_client.disconnect()
except Exception:
logger.debug("[DingTalk] error disconnecting stream client", exc_info=True)
self._dingtalk_client = None
self._stream_client = None
with self._incoming_messages_lock:
self._incoming_messages.clear()
self._card_repliers.clear()
self._card_track_ids.clear()
if self._thread:
self._thread.join(timeout=5)
self._thread = None
logger.info("DingTalk channel stopped")
def _resolve_routing(self, msg: OutboundMessage) -> tuple[str, str, str]:
"""Return (conversation_type, sender_staff_id, conversation_id).
Uses msg.chat_id as the primary routing key; metadata as fallback.
"""
conversation_type = _normalize_conversation_type(msg.metadata.get("conversation_type"))
sender_staff_id = msg.metadata.get("sender_staff_id", "")
conversation_id = msg.metadata.get("conversation_id", "")
if conversation_type == _CONVERSATION_TYPE_GROUP:
conversation_id = msg.chat_id or conversation_id
else:
sender_staff_id = msg.chat_id or sender_staff_id
return conversation_type, sender_staff_id, conversation_id
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
robot_code = self._client_id
# Card mode: stream update to existing AI card
source_key = self._make_card_source_key_from_outbound(msg)
out_track_id = self._card_track_ids.get(source_key)
# ``card_template_id`` enables ``runs.stream`` (non-final + final outbounds).
# If card creation failed, skip non-final chunks to avoid duplicate messages.
if self._card_template_id and not out_track_id and not msg.is_final:
return
if out_track_id:
try:
await self._stream_update_card(
out_track_id,
msg.text,
is_finalize=msg.is_final,
)
except Exception:
logger.warning("[DingTalk] card stream failed, falling back to sampleMarkdown")
if msg.is_final:
self._card_track_ids.pop(source_key, None)
self._card_repliers.pop(out_track_id, None)
await self._send_markdown_fallback(robot_code, conversation_type, sender_staff_id, conversation_id, msg.text)
return
if msg.is_final:
self._card_track_ids.pop(source_key, None)
self._card_repliers.pop(out_track_id, None)
return
# Non-card mode: send sampleMarkdown with retry
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
if conversation_type == _CONVERSATION_TYPE_GROUP:
await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
else:
await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt
logger.warning(
"[DingTalk] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[DingTalk] send failed after %d attempts: %s", _max_retries, last_exc)
if last_exc is None:
raise RuntimeError("DingTalk send failed without an exception from any attempt")
raise last_exc
async def _send_markdown_fallback(
self,
robot_code: str,
conversation_type: str,
sender_staff_id: str,
conversation_id: str,
text: str,
) -> None:
try:
if conversation_type == _CONVERSATION_TYPE_GROUP:
await self._send_group_message(robot_code, conversation_id, text)
else:
await self._send_p2p_message(robot_code, sender_staff_id, text)
except Exception:
logger.exception("[DingTalk] markdown fallback also failed")
raise
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if attachment.size > _MAX_UPLOAD_SIZE_BYTES:
logger.warning("[DingTalk] file too large (%d bytes), skipping: %s", attachment.size, attachment.filename)
return False
conversation_type, sender_staff_id, conversation_id = self._resolve_routing(msg)
robot_code = self._client_id
try:
media_id = await self._upload_media(attachment.actual_path, "image" if attachment.is_image else "file")
if not media_id:
return False
if attachment.is_image:
msg_key = "sampleImageMsg"
msg_param = json.dumps({"photoURL": media_id})
else:
msg_key = "sampleFile"
msg_param = json.dumps(
{
"fileUrl": media_id,
"fileName": attachment.filename,
"fileSize": str(attachment.size),
}
)
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
if conversation_type == _CONVERSATION_TYPE_GROUP:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
headers=self._api_headers(token),
json={
"msgKey": msg_key,
"msgParam": msg_param,
"robotCode": robot_code,
"openConversationId": conversation_id,
},
)
else:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
headers=self._api_headers(token),
json={
"msgKey": msg_key,
"msgParam": msg_param,
"robotCode": robot_code,
"userIds": [sender_staff_id],
},
)
response.raise_for_status()
logger.info("[DingTalk] file sent: %s", attachment.filename)
return True
except (httpx.HTTPError, OSError, ValueError, TypeError, AttributeError):
logger.exception("[DingTalk] failed to send file: %s", attachment.filename)
return False
# -- stream client (runs in dedicated thread) --------------------------
def _run_stream(self, client_id: str, client_secret: str) -> None:
try:
import dingtalk_stream
credential = dingtalk_stream.Credential(client_id, client_secret)
client = dingtalk_stream.DingTalkStreamClient(credential)
self._stream_client = client
client.register_callback_handler(
dingtalk_stream.chatbot.ChatbotMessage.TOPIC,
_DingTalkMessageHandler(self),
)
client.start_forever()
except Exception:
if self._running:
logger.exception("DingTalk Stream Push error")
finally:
self._stream_client = None
def _on_chatbot_message(self, message: Any) -> None:
if not self._running:
return
try:
sender_staff_id = message.sender_staff_id or ""
conversation_type = _normalize_conversation_type(message.conversation_type)
conversation_id = message.conversation_id or ""
msg_id = message.message_id or ""
sender_nick = message.sender_nick or ""
if self._allowed_users and sender_staff_id not in self._allowed_users:
logger.debug("[DingTalk] ignoring message from non-allowed user: %s", sender_staff_id)
return
text = self._extract_text(message)
if not text:
logger.info("[DingTalk] empty text, ignoring message")
return
logger.info(
"[DingTalk] parsed message: conv_type=%s, msg_id=%s, sender=%s(%s), text=%r",
conversation_type,
msg_id,
sender_staff_id,
sender_nick,
text[:100],
)
if _is_dingtalk_command(text):
msg_type = InboundMessageType.COMMAND
else:
msg_type = InboundMessageType.CHAT
# P2P: topic_id=None (single thread per user, like Telegram private chat)
# Group: topic_id=msg_id (each new message starts a new topic, like Feishu)
topic_id: str | None = msg_id if conversation_type == _CONVERSATION_TYPE_GROUP else None
# chat_id uses conversation_id for groups, sender_staff_id for P2P
chat_id = conversation_id if conversation_type == _CONVERSATION_TYPE_GROUP else sender_staff_id
inbound = self._make_inbound(
chat_id=chat_id,
user_id=sender_staff_id,
text=text,
msg_type=msg_type,
thread_ts=msg_id,
metadata={
"conversation_type": conversation_type,
"conversation_id": conversation_id,
"sender_staff_id": sender_staff_id,
"sender_nick": sender_nick,
"message_id": msg_id,
},
)
inbound.topic_id = topic_id
if self._card_template_id:
source_key = self._make_card_source_key(inbound)
with self._incoming_messages_lock:
self._incoming_messages[source_key] = message
if self._main_loop and self._main_loop.is_running():
logger.info("[DingTalk] publishing inbound message to bus (type=%s, msg_id=%s)", msg_type.value, msg_id)
fut = asyncio.run_coroutine_threadsafe(
self._prepare_inbound(chat_id, inbound),
self._main_loop,
)
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "prepare_inbound", mid))
else:
logger.warning("[DingTalk] main loop not running, cannot publish inbound message")
except Exception:
logger.exception("[DingTalk] error processing chatbot message")
@staticmethod
def _extract_text(message: Any) -> str:
msg_type = message.message_type
if msg_type == "text" and message.text:
return message.text.content.strip()
if msg_type == "richText" and message.rich_text_content:
return _extract_text_from_rich_text(message.rich_text_content.rich_text_list).strip()
return ""
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
# Running reply must finish before publish_inbound so AI card tracks are
# registered before the manager emits streaming outbounds.
await self._send_running_reply(chat_id, inbound)
await self.bus.publish_inbound(inbound)
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
conversation_id = inbound.metadata.get("conversation_id", "")
text = "\u23f3 Working on it..."
try:
if self._card_template_id:
source_key = self._make_card_source_key(inbound)
with self._incoming_messages_lock:
chatbot_message = self._incoming_messages.pop(source_key, None)
out_track_id = await self._create_and_deliver_card(
text,
chatbot_message=chatbot_message,
)
if out_track_id:
self._card_track_ids[source_key] = out_track_id
logger.info("[DingTalk] AI card running reply sent for chat=%s", chat_id)
return
robot_code = self._client_id
if conversation_type == _CONVERSATION_TYPE_GROUP:
await self._send_text_message_to_group(robot_code, conversation_id, text)
else:
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
logger.info("[DingTalk] 'Working on it...' reply sent for chat=%s", chat_id)
except Exception:
logger.exception("[DingTalk] failed to send running reply for chat=%s", chat_id)
# -- DingTalk API helpers ----------------------------------------------
async def _get_access_token(self) -> str:
if self._cached_token and time.monotonic() < self._token_expires_at:
return self._cached_token
async with self._token_lock:
if self._cached_token and time.monotonic() < self._token_expires_at:
return self._cached_token
async with httpx.AsyncClient(timeout=httpx.Timeout(10.0)) as client:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/oauth2/accessToken",
json={"appKey": self._client_id, "appSecret": self._client_secret}, # DingTalk API field names
)
response.raise_for_status()
data = response.json()
if not isinstance(data, dict):
raise ValueError(f"DingTalk access token response must be a JSON object, got {type(data).__name__}")
access_token = data.get("accessToken")
if not isinstance(access_token, str) or not access_token.strip():
raise ValueError("DingTalk access token response did not contain a usable accessToken")
raw_expires_in = data.get("expireIn", 7200)
try:
expires_in = int(raw_expires_in)
except (TypeError, ValueError):
logger.warning("[DingTalk] invalid expireIn value %r, using default 7200s", raw_expires_in)
expires_in = 7200
self._cached_token = access_token.strip()
self._token_expires_at = time.monotonic() + expires_in - _TOKEN_REFRESH_MARGIN_SECONDS
return self._cached_token
@staticmethod
def _api_headers(token: str) -> dict[str, str]:
return {
"x-acs-dingtalk-access-token": token,
"Content-Type": "application/json",
}
async def _send_text_message_to_user(self, robot_code: str, user_id: str, text: str) -> None:
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
headers=self._api_headers(token),
json={
"msgKey": "sampleText",
"msgParam": json.dumps({"content": text}),
"robotCode": robot_code,
"userIds": [user_id],
},
)
response.raise_for_status()
async def _send_text_message_to_group(self, robot_code: str, conversation_id: str, text: str) -> None:
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
headers=self._api_headers(token),
json={
"msgKey": "sampleText",
"msgParam": json.dumps({"content": text}),
"robotCode": robot_code,
"openConversationId": conversation_id,
},
)
response.raise_for_status()
async def _send_p2p_message(self, robot_code: str, user_id: str, text: str) -> None:
text = _adapt_markdown_for_dingtalk(text)
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/robot/oToMessages/batchSend",
headers=self._api_headers(token),
json={
"msgKey": "sampleMarkdown",
"msgParam": json.dumps({"title": "DeerFlow", "text": text}),
"robotCode": robot_code,
"userIds": [user_id],
},
)
response.raise_for_status()
data = response.json()
if data.get("processQueryKey"):
logger.info("[DingTalk] P2P message sent to user=%s", user_id)
else:
logger.warning("[DingTalk] P2P send response: %s", data)
async def _send_group_message(
self,
robot_code: str,
conversation_id: str,
text: str,
*,
at_user_ids: list[str] | None = None, # noqa: ARG002
) -> None:
# at_user_ids accepted for call-site compatibility but not passed to the API
# (sampleMarkdown does not support @mentions).
text = _adapt_markdown_for_dingtalk(text)
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=httpx.Timeout(30.0)) as client:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/robot/groupMessages/send",
headers=self._api_headers(token),
json={
"msgKey": "sampleMarkdown",
"msgParam": json.dumps({"title": "DeerFlow", "text": text}),
"robotCode": robot_code,
"openConversationId": conversation_id,
},
)
response.raise_for_status()
data = response.json()
if data.get("processQueryKey"):
logger.info("[DingTalk] group message sent to conversation=%s", conversation_id)
else:
logger.warning("[DingTalk] group send response: %s", data)
# -- AI Card streaming helpers -------------------------------------------
def _make_card_source_key(self, inbound: InboundMessage) -> str:
m = inbound.metadata
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{m.get('message_id', '')}"
def _make_card_source_key_from_outbound(self, msg: OutboundMessage) -> str:
m = msg.metadata
correlation_id = m.get("message_id") or msg.thread_ts or ""
return f"{m.get('conversation_type', '')}:{m.get('sender_staff_id', '')}:{m.get('conversation_id', '')}:{correlation_id}"
async def _create_and_deliver_card(
self,
initial_text: str,
*,
chatbot_message: Any = None,
) -> str | None:
if self._dingtalk_client is None or chatbot_message is None:
logger.warning("[DingTalk] SDK client or chatbot_message unavailable, skipping AI card")
return None
try:
from dingtalk_stream.card_replier import AICardReplier
except ImportError:
logger.warning("[DingTalk] dingtalk-stream card_replier not available")
return None
try:
replier = AICardReplier(self._dingtalk_client, chatbot_message)
card_instance_id = await replier.async_create_and_deliver_card(
card_template_id=self._card_template_id,
card_data={"content": initial_text},
)
if not card_instance_id:
return None
self._card_repliers[card_instance_id] = replier
logger.info("[DingTalk] AI card created: outTrackId=%s", card_instance_id)
return card_instance_id
except Exception:
logger.exception("[DingTalk] failed to create AI card")
return None
async def _stream_update_card(
self,
out_track_id: str,
content: str,
*,
is_finalize: bool = False,
is_error: bool = False,
) -> None:
replier = self._card_repliers.get(out_track_id)
if not replier:
raise RuntimeError(f"No AICardReplier found for track ID {out_track_id}")
await replier.async_streaming(
card_instance_id=out_track_id,
content_key="content",
content_value=content,
append=False,
finished=is_finalize,
failed=is_error,
)
# -- media upload --------------------------------------------------------
async def _upload_media(self, file_path: str | Path, media_type: str) -> str | None:
try:
file_bytes = await asyncio.to_thread(Path(file_path).read_bytes)
token = await self._get_access_token()
async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
response = await client.post(
f"{DINGTALK_API_BASE}/v1.0/files/upload",
headers={"x-acs-dingtalk-access-token": token},
files={"file": ("upload", file_bytes)},
data={"type": media_type},
)
response.raise_for_status()
try:
payload = response.json()
except json.JSONDecodeError:
logger.exception("[DingTalk] failed to decode upload response JSON: %s", file_path)
return None
if not isinstance(payload, dict):
logger.warning("[DingTalk] unexpected upload response type %s for %s", type(payload).__name__, file_path)
return None
return payload.get("mediaId")
except (httpx.HTTPError, OSError):
logger.exception("[DingTalk] failed to upload media: %s", file_path)
return None
@staticmethod
def _log_future_error(fut: Any, name: str, msg_id: str) -> None:
try:
exc = fut.exception()
if exc:
logger.error("[DingTalk] %s failed for msg_id=%s: %s", name, msg_id, exc)
except (asyncio.CancelledError, asyncio.InvalidStateError):
pass
class _DingTalkMessageHandler:
"""Callback handler registered with dingtalk-stream."""
def __init__(self, channel: DingTalkChannel) -> None:
self._channel = channel
def pre_start(self) -> None:
if hasattr(self, "dingtalk_client") and self.dingtalk_client is not None:
self._channel._dingtalk_client = self.dingtalk_client
async def raw_process(self, callback_message: Any) -> Any:
import dingtalk_stream
from dingtalk_stream.frames import Headers
code, message = await self.process(callback_message)
ack_message = dingtalk_stream.AckMessage()
ack_message.code = code
ack_message.headers.message_id = callback_message.headers.message_id
ack_message.headers.content_type = Headers.CONTENT_TYPE_APPLICATION_JSON
ack_message.data = {"response": message}
return ack_message
async def process(self, callback: Any) -> tuple[int, str]:
import dingtalk_stream
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
self._channel._on_chatbot_message(incoming_message)
return dingtalk_stream.AckMessage.STATUS_OK, "OK"

View File

@ -13,6 +13,7 @@ from app.channels.base import Channel
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
logger = logging.getLogger(__name__)
@ -62,6 +63,10 @@ class FeishuChannel(Channel):
self._GetMessageResourceRequest = None
self._thread_lock = threading.Lock()
@property
def supports_streaming(self) -> bool:
return True
async def start(self) -> None:
if self._running:
return
@ -344,8 +349,9 @@ class FeishuChannel(Channel):
return f"Failed to obtain the [{type}]"
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
uploads_dir = paths.sandbox_uploads_dir(thread_id).resolve()
user_id = get_effective_user_id()
paths.ensure_thread_dirs(thread_id, user_id=user_id)
uploads_dir = paths.sandbox_uploads_dir(thread_id, user_id=user_id).resolve()
ext = "png" if type == "image" else "bin"
raw_filename = getattr(response, "file_name", "") or f"feishu_{file_key[-12:]}.{ext}"

View File

@ -1,4 +1,4 @@
"""ChannelManager — consumes inbound messages and dispatches them to the DeerFlow agent via LangGraph Server."""
"""ChannelManager — consumes inbound messages and dispatches them to the DeerFlow agent via Gateway."""
from __future__ import annotations
@ -17,10 +17,13 @@ from langgraph_sdk.errors import ConflictError
from app.channels.commands import KNOWN_CHANNEL_COMMANDS
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
from app.gateway.internal_auth import create_internal_auth_headers
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
DEFAULT_LANGGRAPH_URL = "http://localhost:2024"
DEFAULT_LANGGRAPH_URL = "http://localhost:8001/api"
DEFAULT_GATEWAY_URL = "http://localhost:8001"
DEFAULT_ASSISTANT_ID = "lead_agent"
CUSTOM_AGENT_NAME_PATTERN = re.compile(r"^[A-Za-z0-9-]+$")
@ -35,6 +38,7 @@ STREAM_UPDATE_MIN_INTERVAL_SECONDS = 0.35
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
CHANNEL_CAPABILITIES = {
"dingtalk": {"supports_streaming": False},
"discord": {"supports_streaming": False},
"feishu": {"supports_streaming": True},
"slack": {"supports_streaming": False},
@ -45,6 +49,13 @@ CHANNEL_CAPABILITIES = {
InboundFileReader = Callable[[dict[str, Any], httpx.AsyncClient], Awaitable[bytes | None]]
_METADATA_DROP_KEYS = frozenset({"raw_message", "ref_msg"})
def _slim_metadata(meta: dict[str, Any]) -> dict[str, Any]:
"""Return a shallow copy of *meta* with known-large keys removed."""
return {k: v for k, v in meta.items() if k not in _METADATA_DROP_KEYS}
INBOUND_FILE_READERS: dict[str, InboundFileReader] = {}
@ -342,14 +353,15 @@ def _resolve_attachments(thread_id: str, artifacts: list[str]) -> list[ResolvedA
attachments: list[ResolvedAttachment] = []
paths = get_paths()
outputs_dir = paths.sandbox_outputs_dir(thread_id).resolve()
user_id = get_effective_user_id()
outputs_dir = paths.sandbox_outputs_dir(thread_id, user_id=user_id).resolve()
for virtual_path in artifacts:
# Security: only allow files from the agent outputs directory
if not virtual_path.startswith(_OUTPUTS_VIRTUAL_PREFIX):
logger.warning("[Manager] rejected non-outputs artifact path: %s", virtual_path)
continue
try:
actual = paths.resolve_virtual_path(thread_id, virtual_path)
actual = paths.resolve_virtual_path(thread_id, virtual_path, user_id=user_id)
# Verify the resolved path is actually under the outputs directory
# (guards against path-traversal even after prefix check)
try:
@ -507,7 +519,7 @@ class ChannelManager:
"""Core dispatcher that bridges IM channels to the DeerFlow agent.
It reads from the MessageBus inbound queue, creates/reuses threads on
the LangGraph Server, sends messages via ``runs.wait``, and publishes
Gateway's LangGraph-compatible API, sends messages via ``runs.wait``, and publishes
outbound responses back through the bus.
"""
@ -532,12 +544,20 @@ class ChannelManager:
self._default_session = _as_dict(default_session)
self._channel_sessions = dict(channel_sessions or {})
self._client = None # lazy init — langgraph_sdk async client
self._csrf_token = generate_csrf_token()
self._semaphore: asyncio.Semaphore | None = None
self._running = False
self._task: asyncio.Task | None = None
@staticmethod
def _channel_supports_streaming(channel_name: str) -> bool:
from .service import get_channel_service
service = get_channel_service()
if service:
channel = service.get_channel(channel_name)
if channel is not None:
return channel.supports_streaming
return CHANNEL_CAPABILITIES.get(channel_name, {}).get("supports_streaming", False)
def _resolve_session_layer(self, msg: InboundMessage) -> tuple[dict[str, Any], dict[str, Any]]:
@ -593,7 +613,14 @@ class ChannelManager:
if self._client is None:
from langgraph_sdk import get_client
self._client = get_client(url=self._langgraph_url)
self._client = get_client(
url=self._langgraph_url,
headers={
**create_internal_auth_headers(),
CSRF_HEADER_NAME: self._csrf_token,
"Cookie": f"{CSRF_COOKIE_NAME}={self._csrf_token}",
},
)
return self._client
# -- lifecycle ---------------------------------------------------------
@ -676,7 +703,7 @@ class ChannelManager:
# -- chat handling -----------------------------------------------------
async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread on the LangGraph Server and store the mapping."""
"""Create a new thread through Gateway and store the mapping."""
thread = await client.threads.create()
thread_id = thread["thread_id"]
self.store.set_thread_id(
@ -686,7 +713,7 @@ class ChannelManager:
topic_id=msg.topic_id,
user_id=msg.user_id,
)
logger.info("[Manager] new thread created on LangGraph Server: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
return thread_id
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
@ -769,6 +796,7 @@ class ChannelManager:
artifacts=artifacts,
attachments=attachments,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
await self.bus.publish_outbound(outbound)
@ -830,6 +858,7 @@ class ChannelManager:
text=latest_text,
is_final=False,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
)
last_published_text = latest_text
@ -874,6 +903,7 @@ class ChannelManager:
attachments=attachments,
is_final=True,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
)
@ -893,7 +923,7 @@ class ChannelManager:
return
if command == "new":
# Create a new thread on the LangGraph Server
# Create a new thread through Gateway
client = self._get_client()
thread = await client.threads.create()
new_thread_id = thread["thread_id"]
@ -932,6 +962,7 @@ class ChannelManager:
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=reply,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)
@ -965,5 +996,6 @@ class ChannelManager:
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id) or "",
text=error_text,
thread_ts=msg.thread_ts,
metadata=_slim_metadata(msg.metadata),
)
await self.bus.publish_outbound(outbound)

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any
from app.channels.base import Channel
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
@ -13,8 +13,12 @@ from app.channels.store import ChannelStore
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
# Channel name → import path for lazy loading
_CHANNEL_REGISTRY: dict[str, str] = {
"dingtalk": "app.channels.dingtalk:DingTalkChannel",
"discord": "app.channels.discord:DiscordChannel",
"feishu": "app.channels.feishu:FeishuChannel",
"slack": "app.channels.slack:SlackChannel",
@ -23,6 +27,17 @@ _CHANNEL_REGISTRY: dict[str, str] = {
"wecom": "app.channels.wecom:WeComChannel",
}
# Keys that indicate a user has configured credentials for a channel.
_CHANNEL_CREDENTIAL_KEYS: dict[str, list[str]] = {
"dingtalk": ["client_id", "client_secret"],
"discord": ["bot_token"],
"feishu": ["app_id", "app_secret"],
"slack": ["bot_token", "app_token"],
"telegram": ["bot_token"],
"wecom": ["bot_id", "bot_secret"],
"wechat": ["bot_token"],
}
_CHANNELS_LANGGRAPH_URL_ENV = "DEER_FLOW_CHANNELS_LANGGRAPH_URL"
_CHANNELS_GATEWAY_URL_ENV = "DEER_FLOW_CHANNELS_GATEWAY_URL"
@ -65,14 +80,15 @@ class ChannelService:
self._running = False
@classmethod
def from_app_config(cls) -> ChannelService:
def from_app_config(cls, app_config: AppConfig | None = None) -> ChannelService:
"""Create a ChannelService from the application config."""
from deerflow.config.app_config import get_app_config
if app_config is None:
from deerflow.config.app_config import get_app_config
config = get_app_config()
app_config = get_app_config()
channels_config = {}
# extra fields are allowed by AppConfig (extra="allow")
extra = config.model_extra or {}
extra = app_config.model_extra or {}
if "channels" in extra:
channels_config = extra["channels"]
return cls(channels_config=channels_config)
@ -88,7 +104,16 @@ class ChannelService:
if not isinstance(channel_config, dict):
continue
if not channel_config.get("enabled", False):
logger.info("Channel %s is disabled, skipping", name)
cred_keys = _CHANNEL_CREDENTIAL_KEYS.get(name, [])
has_creds = any(not isinstance(channel_config.get(k), bool) and channel_config.get(k) is not None and str(channel_config[k]).strip() for k in cred_keys)
if has_creds:
logger.warning(
"Channel '%s' has credentials configured but is disabled. Set enabled: true under channels.%s in config.yaml to activate it.",
name,
name,
)
else:
logger.info("Channel %s is disabled, skipping", name)
continue
await self._start_channel(name, channel_config)
@ -143,11 +168,16 @@ class ChannelService:
try:
channel = channel_cls(bus=self.bus, config=config)
await channel.start()
self._channels[name] = channel
await channel.start()
if not channel.is_running:
self._channels.pop(name, None)
logger.error("Channel %s did not enter a running state after start()", name)
return False
logger.info("Channel %s started", name)
return True
except Exception:
self._channels.pop(name, None)
logger.exception("Failed to start channel %s", name)
return False
@ -182,12 +212,12 @@ def get_channel_service() -> ChannelService | None:
return _channel_service
async def start_channel_service() -> ChannelService:
async def start_channel_service(app_config: AppConfig | None = None) -> ChannelService:
"""Create and start the global ChannelService from app config."""
global _channel_service
if _channel_service is not None:
return _channel_service
_channel_service = ChannelService.from_app_config()
_channel_service = ChannelService.from_app_config(app_config)
await _channel_service.start()
return _channel_service

View File

@ -16,13 +16,31 @@ logger = logging.getLogger(__name__)
_slack_md_converter = SlackMarkdownConverter()
def _normalize_allowed_users(allowed_users: Any) -> set[str]:
if allowed_users is None:
return set()
if isinstance(allowed_users, str):
values = [allowed_users]
elif isinstance(allowed_users, list | tuple | set):
values = allowed_users
else:
logger.warning(
"Slack allowed_users should be a list of Slack user IDs or a single Slack user ID string; treating %s as one string value",
type(allowed_users).__name__,
)
values = [allowed_users]
return {str(user_id) for user_id in values if str(user_id)}
class SlackChannel(Channel):
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
Configuration keys (in ``config.yaml`` under ``channels.slack``):
- ``bot_token``: Slack Bot User OAuth Token (xoxb-...).
- ``app_token``: Slack App-Level Token (xapp-...) for Socket Mode.
- ``allowed_users``: (optional) List of allowed Slack user IDs. Empty = allow all.
- ``allowed_users``: (optional) List of allowed Slack user IDs, or a
single Slack user ID string as shorthand. Empty = allow all. Other
scalar values are treated as a single string with a warning.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
@ -30,7 +48,7 @@ class SlackChannel(Channel):
self._socket_client = None
self._web_client = None
self._loop: asyncio.AbstractEventLoop | None = None
self._allowed_users: set[str] = {str(user_id) for user_id in config.get("allowed_users", [])}
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
async def start(self) -> None:
if self._running:

View File

@ -29,6 +29,10 @@ class WeComChannel(Channel):
self._ws_stream_ids: dict[str, str] = {}
self._working_message = "Working on it..."
@property
def supports_streaming(self) -> bool:
return True
def _clear_ws_context(self, thread_ts: str | None) -> None:
if not thread_ts:
return

View File

@ -1,16 +1,23 @@
import asyncio
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.gateway.auth_middleware import AuthMiddleware
from app.gateway.config import get_gateway_config
from app.gateway.csrf_middleware import CSRFMiddleware
from app.gateway.deps import langgraph_runtime
from app.gateway.routers import (
agents,
artifacts,
assistants_compat,
auth,
channels,
feedback,
mcp,
memory,
models,
@ -21,9 +28,13 @@ from app.gateway.routers import (
threads,
uploads,
)
from deerflow.config.app_config import get_app_config
from deerflow.config import app_config as deerflow_app_config
from deerflow.config.app_config import apply_logging_level
# Configure logging
AppConfig = deerflow_app_config.AppConfig
get_app_config = deerflow_app_config.get_app_config
# Default logging; lifespan overrides from config.yaml log_level.
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
@ -32,6 +43,120 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
# Upper bound (seconds) each lifespan shutdown hook is allowed to run.
# Bounds worker exit time so uvicorn's reload supervisor does not keep
# firing signals into a worker that is stuck waiting for shutdown cleanup.
_SHUTDOWN_HOOK_TIMEOUT_SECONDS = 5.0
async def _ensure_admin_user(app: FastAPI) -> None:
"""Startup hook: handle first boot and migrate orphan threads otherwise.
After admin creation, migrate orphan threads from the LangGraph
store (metadata.user_id unset) to the admin account. This is the
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
authentication have existing LangGraph thread data that needs an
owner assigned.
First boot (no admin exists):
- Does NOT create any user accounts automatically.
- The operator must visit ``/setup`` to create the first admin.
Subsequent boots (admin already exists):
- Runs the one-time "no-auth → with-auth" orphan thread migration for
existing LangGraph thread metadata that has no owner_id.
No SQL persistence migration is needed: the four user_id columns
(threads_meta, runs, run_events, feedback) only come into existence
alongside the auth module via create_all, so freshly created tables
never contain NULL-owner rows.
"""
from sqlalchemy import select
from app.gateway.deps import get_local_provider
from deerflow.persistence.engine import get_session_factory
from deerflow.persistence.user.model import UserRow
try:
provider = get_local_provider()
except RuntimeError:
# Auth persistence may not be initialized in some test/boot paths.
# Skip admin migration work rather than failing gateway startup.
logger.warning("Auth persistence not ready; skipping admin bootstrap check")
return
sf = get_session_factory()
if sf is None:
return
admin_count = await provider.count_admin_users()
if admin_count == 0:
logger.info("=" * 60)
logger.info(" First boot detected — no admin account exists.")
logger.info(" Visit /setup to complete admin account creation.")
logger.info("=" * 60)
return
# Admin already exists — run orphan thread migration for any
# LangGraph thread metadata that pre-dates the auth module.
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
return # Should not happen (admin_count > 0 above), but be safe.
admin_id = str(row.id)
# LangGraph store orphan migration — non-fatal.
# This covers the "no-auth → with-auth" upgrade path for users
# whose existing LangGraph thread metadata has no user_id set.
store = getattr(app.state, "store", None)
if store is not None:
try:
migrated = await _migrate_orphaned_threads(store, admin_id)
if migrated:
logger.info("Migrated %d orphan LangGraph thread(s) to admin", migrated)
except Exception:
logger.exception("LangGraph thread migration failed (non-fatal)")
async def _iter_store_items(store, namespace, *, page_size: int = 500):
"""Paginated async iterator over a LangGraph store namespace.
Replaces the old hardcoded ``limit=1000`` call with a cursor-style
loop so that environments with more than one page of orphans do
not silently lose data. Terminates when a page is empty OR when a
short page arrives (indicating the last page).
"""
offset = 0
while True:
batch = await store.asearch(namespace, limit=page_size, offset=offset)
if not batch:
return
for item in batch:
yield item
if len(batch) < page_size:
return
offset += page_size
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
"""Migrate LangGraph store threads with no user_id to the given admin.
Uses cursor pagination so all orphans are migrated regardless of
count. Returns the number of rows migrated.
"""
migrated = 0
async for item in _iter_store_items(store, ("threads",)):
metadata = item.value.get("metadata", {})
if not metadata.get("user_id"):
metadata["user_id"] = admin_user_id
item.value["metadata"] = metadata
await store.aput(("threads",), item.key, item.value)
migrated += 1
return migrated
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
@ -39,7 +164,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Load config and check necessary environment variables at startup
try:
get_app_config()
app.state.config = get_app_config()
apply_logging_level(app.state.config.log_level)
logger.info("Configuration loaded successfully")
except Exception as e:
error_msg = f"Failed to load configuration during gateway startup: {e}"
@ -52,22 +178,34 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
logger.info("LangGraph runtime initialised")
# Ensure admin user exists (auto-create on first boot)
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
await _ensure_admin_user(app)
# Start IM channel service if any channels are configured
try:
from app.channels.service import start_channel_service
channel_service = await start_channel_service()
channel_service = await start_channel_service(app.state.config)
logger.info("Channel service started: %s", channel_service.get_status())
except Exception:
logger.exception("No IM channels configured or channel service failed to start")
yield
# Stop channel service on shutdown
# Stop channel service on shutdown (bounded to prevent worker hang)
try:
from app.channels.service import stop_channel_service
await stop_channel_service()
await asyncio.wait_for(
stop_channel_service(),
timeout=_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
except TimeoutError:
logger.warning(
"Channel service shutdown exceeded %.1fs; proceeding with worker exit.",
_SHUTDOWN_HOOK_TIMEOUT_SECONDS,
)
except Exception:
logger.exception("Failed to stop channel service")
@ -80,6 +218,8 @@ def create_app() -> FastAPI:
Returns:
Configured FastAPI application instance.
"""
config = get_gateway_config()
docs_kwargs = {"docs_url": "/docs", "redoc_url": "/redoc", "openapi_url": "/openapi.json"} if config.enable_docs else {"docs_url": None, "redoc_url": None, "openapi_url": None}
app = FastAPI(
title="DeerFlow API Gateway",
@ -104,9 +244,7 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
""",
version="0.1.0",
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
**docs_kwargs,
openapi_tags=[
{
"name": "models",
@ -163,7 +301,31 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
],
)
# CORS is handled by nginx - no need for FastAPI middleware
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
app.add_middleware(AuthMiddleware)
# CSRF: Double Submit Cookie pattern for state-changing requests
app.add_middleware(CSRFMiddleware)
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
# In production, nginx handles CORS and no middleware is needed.
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
if cors_origins_env:
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
# Validate: wildcard origin with credentials is a security misconfiguration
for origin in cors_origins:
if origin == "*":
logger.error("GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. This is a security misconfiguration — browsers will reject the response. Use explicit scheme://host:port origins instead.")
cors_origins = [o for o in cors_origins if o != "*"]
break
if cors_origins:
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Include routers
# Models API is mounted at /api/models
@ -199,6 +361,12 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
# Assistants compatibility API (LangGraph Platform stub)
app.include_router(assistants_compat.router)
# Auth API is mounted at /api/v1/auth
app.include_router(auth.router)
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
app.include_router(feedback.router)
# Thread Runs API (LangGraph Platform-compatible runs lifecycle)
app.include_router(thread_runs.router)

View File

@ -0,0 +1,42 @@
"""Authentication module for DeerFlow.
This module provides:
- JWT-based authentication
- Provider Factory pattern for extensible auth methods
- UserRepository interface for storage backends (SQLite)
"""
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.models import User, UserResponse
from app.gateway.auth.password import hash_password, verify_password
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
__all__ = [
# Config
"AuthConfig",
"get_auth_config",
"set_auth_config",
# Errors
"AuthErrorCode",
"AuthErrorResponse",
"TokenError",
# JWT
"TokenPayload",
"create_access_token",
"decode_token",
# Password
"hash_password",
"verify_password",
# Models
"User",
"UserResponse",
# Providers
"AuthProvider",
"LocalAuthProvider",
# Repository
"UserRepository",
]

View File

@ -0,0 +1,57 @@
"""Authentication configuration for DeerFlow."""
import logging
import os
import secrets
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup.
Note: the ``users`` table now lives in the shared persistence
database managed by ``deerflow.persistence.engine``. The old
``users_db_path`` config key has been removed user storage is
configured through ``config.database`` like every other table.
"""
jwt_secret: str = Field(
...,
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
)
token_expiry_days: int = Field(default=7, ge=1, le=30)
oauth_github_client_id: str | None = Field(default=None)
oauth_github_client_secret: str | None = Field(default=None)
_auth_config: AuthConfig | None = None
def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
if _auth_config is None:
from dotenv import load_dotenv
load_dotenv()
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = secrets.token_urlsafe(32)
os.environ["AUTH_JWT_SECRET"] = jwt_secret
logger.warning(
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
"Sessions will be invalidated on restart. "
"For production, add AUTH_JWT_SECRET to your .env file: "
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
_auth_config = AuthConfig(jwt_secret=jwt_secret)
return _auth_config
def set_auth_config(config: AuthConfig) -> None:
"""Set the global AuthConfig instance (for testing)."""
global _auth_config
_auth_config = config

View File

@ -0,0 +1,48 @@
"""Write initial admin credentials to a restricted file instead of logs.
Logging secrets to stdout/stderr is a well-known CodeQL finding
(py/clear-text-logging-sensitive-data) in production those logs
get collected into ELK/Splunk/etc and become a secret sprawl
source. This helper writes the credential to a 0600 file that only
the process user can read, and returns the path so the caller can
log **the path** (not the password) for the operator to pick up.
"""
from __future__ import annotations
import os
from pathlib import Path
from deerflow.config.paths import get_paths
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
"""Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``.
The file is created **atomically** with mode 0600 via ``os.open``
so the password is never world-readable, even for the single syscall
window between ``write_text`` and ``chmod``.
``label`` distinguishes "initial" (fresh creation) from "reset"
(password reset) in the file header so an operator picking up the
file after a restart can tell which event produced it.
Returns the absolute :class:`Path` to the file.
"""
target = get_paths().base_dir / _CREDENTIAL_FILENAME
target.parent.mkdir(parents=True, exist_ok=True)
content = (
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
)
# Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the
# reset-password path can rewrite an existing file without a
# separate unlink-then-create dance.
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as fh:
fh.write(content)
return target.resolve()

View File

@ -0,0 +1,45 @@
"""Typed error definitions for auth module.
AuthErrorCode: exhaustive enum of all auth failure conditions.
TokenError: exhaustive enum of JWT decode failures.
AuthErrorResponse: structured error payload for HTTP responses.
"""
from enum import StrEnum
from pydantic import BaseModel
class AuthErrorCode(StrEnum):
"""Exhaustive list of auth error conditions."""
INVALID_CREDENTIALS = "invalid_credentials"
TOKEN_EXPIRED = "token_expired"
TOKEN_INVALID = "token_invalid"
USER_NOT_FOUND = "user_not_found"
EMAIL_ALREADY_EXISTS = "email_already_exists"
PROVIDER_NOT_FOUND = "provider_not_found"
NOT_AUTHENTICATED = "not_authenticated"
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
class TokenError(StrEnum):
"""Exhaustive list of JWT decode failure reasons."""
EXPIRED = "expired"
INVALID_SIGNATURE = "invalid_signature"
MALFORMED = "malformed"
class AuthErrorResponse(BaseModel):
"""Structured error response — replaces bare `detail` strings."""
code: AuthErrorCode
message: str
def token_error_to_code(err: TokenError) -> AuthErrorCode:
"""Map TokenError to AuthErrorCode — single source of truth."""
if err == TokenError.EXPIRED:
return AuthErrorCode.TOKEN_EXPIRED
return AuthErrorCode.TOKEN_INVALID

View File

@ -0,0 +1,55 @@
"""JWT token creation and verification."""
from datetime import UTC, datetime, timedelta
import jwt
from pydantic import BaseModel
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import TokenError
class TokenPayload(BaseModel):
"""JWT token payload."""
sub: str # user_id
exp: datetime
iat: datetime | None = None
ver: int = 0 # token_version — must match User.token_version
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
"""Create a JWT access token.
Args:
user_id: The user's UUID as string
expires_delta: Optional custom expiry, defaults to 7 days
token_version: User's current token_version for invalidation
Returns:
Encoded JWT string
"""
config = get_auth_config()
expiry = expires_delta or timedelta(days=config.token_expiry_days)
now = datetime.now(UTC)
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
def decode_token(token: str) -> TokenPayload | TokenError:
"""Decode and validate a JWT token.
Returns:
TokenPayload if valid, or a specific TokenError variant.
"""
config = get_auth_config()
try:
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
return TokenPayload(**payload)
except jwt.ExpiredSignatureError:
return TokenError.EXPIRED
except jwt.InvalidSignatureError:
return TokenError.INVALID_SIGNATURE
except jwt.PyJWTError:
return TokenError.MALFORMED

View File

@ -0,0 +1,104 @@
"""Local email/password authentication provider."""
import logging
from app.gateway.auth.models import User
from app.gateway.auth.password import hash_password_async, needs_rehash, verify_password_async
from app.gateway.auth.providers import AuthProvider
from app.gateway.auth.repositories.base import UserRepository
logger = logging.getLogger(__name__)
class LocalAuthProvider(AuthProvider):
"""Email/password authentication provider using local database."""
def __init__(self, repository: UserRepository):
"""Initialize with a UserRepository.
Args:
repository: UserRepository implementation (SQLite)
"""
self._repo = repository
async def authenticate(self, credentials: dict) -> User | None:
"""Authenticate with email and password.
Args:
credentials: dict with 'email' and 'password' keys
Returns:
User if authentication succeeds, None otherwise
"""
email = credentials.get("email")
password = credentials.get("password")
if not email or not password:
return None
user = await self._repo.get_user_by_email(email)
if user is None:
return None
if user.password_hash is None:
# OAuth user without local password
return None
if not await verify_password_async(password, user.password_hash):
return None
if needs_rehash(user.password_hash):
try:
user.password_hash = await hash_password_async(password)
await self._repo.update_user(user)
except Exception:
# Rehash is an opportunistic upgrade; a transient DB error must not
# prevent an otherwise-valid login from succeeding.
logger.warning("Failed to rehash password for user %s; login will still succeed", user.email, exc_info=True)
return user
async def get_user(self, user_id: str) -> User | None:
"""Get user by ID."""
return await self._repo.get_user_by_id(user_id)
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
"""Create a new local user.
Args:
email: User email address
password: Plain text password (will be hashed)
system_role: Role to assign ("admin" or "user")
needs_setup: If True, user must complete setup on first login
Returns:
Created User instance
"""
password_hash = await hash_password_async(password) if password else None
user = User(
email=email,
password_hash=password_hash,
system_role=system_role,
needs_setup=needs_setup,
)
return await self._repo.create_user(user)
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID."""
return await self._repo.get_user_by_oauth(provider, oauth_id)
async def count_users(self) -> int:
"""Return total number of registered users."""
return await self._repo.count_users()
async def count_admin_users(self) -> int:
"""Return number of admin users."""
return await self._repo.count_admin_users()
async def update_user(self, user: User) -> User:
"""Update an existing user."""
return await self._repo.update_user(user)
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email."""
return await self._repo.get_user_by_email(email)

View File

@ -0,0 +1,41 @@
"""User Pydantic models for authentication."""
from datetime import UTC, datetime
from typing import Literal
from uuid import UUID, uuid4
from pydantic import BaseModel, ConfigDict, EmailStr, Field
def _utc_now() -> datetime:
"""Return current UTC time (timezone-aware)."""
return datetime.now(UTC)
class User(BaseModel):
"""Internal user representation."""
model_config = ConfigDict(from_attributes=True)
id: UUID = Field(default_factory=uuid4, description="Primary key")
email: EmailStr = Field(..., description="Unique email address")
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
system_role: Literal["admin", "user"] = Field(default="user")
created_at: datetime = Field(default_factory=_utc_now)
# OAuth linkage (optional)
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
# Auth lifecycle
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
class UserResponse(BaseModel):
"""Response model for user info endpoint."""
id: str
email: str
system_role: Literal["admin", "user"]
needs_setup: bool = False

View File

@ -0,0 +1,81 @@
"""Password hashing utilities with versioned hash format.
Hash format: ``$dfv<N>$<bcrypt_hash>`` where ``<N>`` is the version.
- **v1** (legacy): ``bcrypt(password)`` plain bcrypt, susceptible to
72-byte silent truncation.
- **v2** (current): ``bcrypt(b64(sha256(password)))`` SHA-256 pre-hash
avoids the 72-byte truncation limit so the full password contributes
to the hash.
Verification auto-detects the version and falls back to v1 for hashes
without a prefix, so existing deployments upgrade transparently on next
login.
"""
import asyncio
import base64
import hashlib
import bcrypt
_CURRENT_VERSION = 2
_PREFIX_V2 = "$dfv2$"
_PREFIX_V1 = "$dfv1$"
def _pre_hash_v2(password: str) -> bytes:
"""SHA-256 pre-hash to bypass bcrypt's 72-byte limit."""
return base64.b64encode(hashlib.sha256(password.encode("utf-8")).digest())
def hash_password(password: str) -> str:
"""Hash a password (current version: v2 — SHA-256 + bcrypt)."""
raw = bcrypt.hashpw(_pre_hash_v2(password), bcrypt.gensalt()).decode("utf-8")
return f"{_PREFIX_V2}{raw}"
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""Verify a password, auto-detecting the hash version.
Accepts v2 (``$dfv2$``), v1 (``$dfv1$``), and bare bcrypt hashes
(treated as v1 for backward compatibility with pre-versioning data).
"""
try:
if hashed_password.startswith(_PREFIX_V2):
bcrypt_hash = hashed_password[len(_PREFIX_V2) :]
return bcrypt.checkpw(_pre_hash_v2(plain_password), bcrypt_hash.encode("utf-8"))
if hashed_password.startswith(_PREFIX_V1):
bcrypt_hash = hashed_password[len(_PREFIX_V1) :]
else:
bcrypt_hash = hashed_password
return bcrypt.checkpw(plain_password.encode("utf-8"), bcrypt_hash.encode("utf-8"))
except ValueError:
# bcrypt raises ValueError for malformed or corrupt hashes (e.g., invalid salt).
# Fail closed rather than crashing the request.
return False
def needs_rehash(hashed_password: str) -> bool:
"""Return True if the hash uses an older version and should be rehashed."""
return not hashed_password.startswith(_PREFIX_V2)
async def hash_password_async(password: str) -> str:
"""Hash a password using bcrypt (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password hashing.
"""
return await asyncio.to_thread(hash_password, password)
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
"""Verify a password against its hash (non-blocking).
Wraps the blocking bcrypt operation in a thread pool to avoid
blocking the event loop during password verification.
"""
return await asyncio.to_thread(verify_password, plain_password, hashed_password)

View File

@ -0,0 +1,24 @@
"""Auth provider abstraction."""
from abc import ABC, abstractmethod
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@abstractmethod
async def authenticate(self, credentials: dict) -> "User | None":
"""Authenticate user with given credentials.
Returns User if authentication succeeds, None otherwise.
"""
raise NotImplementedError
@abstractmethod
async def get_user(self, user_id: str) -> "User | None":
"""Retrieve user by ID."""
raise NotImplementedError
# Import User at runtime to avoid circular imports
from app.gateway.auth.models import User # noqa: E402

View File

@ -0,0 +1,102 @@
"""User repository interface for abstracting database operations."""
from abc import ABC, abstractmethod
from app.gateway.auth.models import User
class UserNotFoundError(LookupError):
"""Raised when a user repository operation targets a non-existent row.
Subclass of :class:`LookupError` so callers that already catch
``LookupError`` for "missing entity" can keep working unchanged,
while specific call sites can pin to this class to distinguish
"concurrent delete during update" from other lookups.
"""
class UserRepository(ABC):
"""Abstract interface for user data storage.
Implement this interface to support different storage backends
(SQLite)
"""
@abstractmethod
async def create_user(self, user: User) -> User:
"""Create a new user.
Args:
user: User object to create
Returns:
Created User with ID assigned
Raises:
ValueError: If email already exists
"""
raise NotImplementedError
@abstractmethod
async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID.
Args:
user_id: User UUID as string
Returns:
User if found, None otherwise
"""
raise NotImplementedError
@abstractmethod
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email.
Args:
email: User email address
Returns:
User if found, None otherwise
"""
raise NotImplementedError
@abstractmethod
async def update_user(self, user: User) -> User:
"""Update an existing user.
Args:
user: User object with updated fields
Returns:
Updated User
Raises:
UserNotFoundError: If no row exists for ``user.id``. This is
a hard failure (not a no-op) so callers cannot mistake a
concurrent-delete race for a successful update.
"""
raise NotImplementedError
@abstractmethod
async def count_users(self) -> int:
"""Return total number of registered users."""
raise NotImplementedError
@abstractmethod
async def count_admin_users(self) -> int:
"""Return number of users with system_role == 'admin'."""
raise NotImplementedError
@abstractmethod
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID.
Args:
provider: OAuth provider name (e.g. 'github', 'google')
oauth_id: User ID from the OAuth provider
Returns:
User if found, None otherwise
"""
raise NotImplementedError

View File

@ -0,0 +1,127 @@
"""SQLAlchemy-backed UserRepository implementation.
Uses the shared async session factory from
``deerflow.persistence.engine`` the ``users`` table lives in the
same database as ``threads_meta``, ``runs``, ``run_events``, and
``feedback``.
Constructor takes the session factory directly (same pattern as the
other four repositories in ``deerflow.persistence.*``). Callers
construct this after ``init_engine_from_config()`` has run.
"""
from __future__ import annotations
from datetime import UTC
from uuid import UUID
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
from deerflow.persistence.user.model import UserRow
class SQLiteUserRepository(UserRepository):
"""Async user repository backed by the shared SQLAlchemy engine."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
# ── Converters ────────────────────────────────────────────────────
@staticmethod
def _row_to_user(row: UserRow) -> User:
return User(
id=UUID(row.id),
email=row.email,
password_hash=row.password_hash,
system_role=row.system_role, # type: ignore[arg-type]
# SQLite loses tzinfo on read; reattach UTC so downstream
# code can compare timestamps reliably.
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
oauth_provider=row.oauth_provider,
oauth_id=row.oauth_id,
needs_setup=row.needs_setup,
token_version=row.token_version,
)
@staticmethod
def _user_to_row(user: User) -> UserRow:
return UserRow(
id=str(user.id),
email=user.email,
password_hash=user.password_hash,
system_role=user.system_role,
created_at=user.created_at,
oauth_provider=user.oauth_provider,
oauth_id=user.oauth_id,
needs_setup=user.needs_setup,
token_version=user.token_version,
)
# ── CRUD ──────────────────────────────────────────────────────────
async def create_user(self, user: User) -> User:
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
row = self._user_to_row(user)
async with self._sf() as session:
session.add(row)
try:
await session.commit()
except IntegrityError as exc:
await session.rollback()
raise ValueError(f"Email already registered: {user.email}") from exc
return user
async def get_user_by_id(self, user_id: str) -> User | None:
async with self._sf() as session:
row = await session.get(UserRow, user_id)
return self._row_to_user(row) if row is not None else None
async def get_user_by_email(self, email: str) -> User | None:
stmt = select(UserRow).where(UserRow.email == email)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
async def update_user(self, user: User) -> User:
async with self._sf() as session:
row = await session.get(UserRow, str(user.id))
if row is None:
# Hard fail on concurrent delete: callers (reset_admin,
# password change handlers, _ensure_admin_user) all
# fetched the user just before this call, so a missing
# row here means the row vanished underneath us. Silent
# success would let the caller log "password reset" for
# a row that no longer exists.
raise UserNotFoundError(f"User {user.id} no longer exists")
row.email = user.email
row.password_hash = user.password_hash
row.system_role = user.system_role
row.oauth_provider = user.oauth_provider
row.oauth_id = user.oauth_id
row.needs_setup = user.needs_setup
row.token_version = user.token_version
await session.commit()
return user
async def count_users(self) -> int:
stmt = select(func.count()).select_from(UserRow)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def count_admin_users(self) -> int:
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None

View File

@ -0,0 +1,91 @@
"""CLI tool to reset an admin password.
Usage:
python -m app.gateway.auth.reset_admin
python -m app.gateway.auth.reset_admin --email admin@example.com
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
(mode 0600) instead of printing it, so CI / log aggregators never see
the cleartext secret.
"""
from __future__ import annotations
import argparse
import asyncio
import secrets
import sys
from sqlalchemy import select
from app.gateway.auth.credential_file import write_initial_credentials
from app.gateway.auth.password import hash_password
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.user.model import UserRow
async def _run(email: str | None) -> int:
from deerflow.config import get_app_config
from deerflow.persistence.engine import (
close_engine,
get_session_factory,
init_engine_from_config,
)
config = get_app_config()
await init_engine_from_config(config.database)
try:
sf = get_session_factory()
if sf is None:
print("Error: persistence engine not available (check config.database).", file=sys.stderr)
return 1
repo = SQLiteUserRepository(sf)
if email:
user = await repo.get_user_by_email(email)
else:
# Find first admin via direct SELECT — repository does not
# expose a "first admin" helper and we do not want to add
# one just for this CLI.
async with sf() as session:
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
row = (await session.execute(stmt)).scalar_one_or_none()
if row is None:
user = None
else:
user = await repo.get_user_by_id(row.id)
if user is None:
if email:
print(f"Error: user '{email}' not found.", file=sys.stderr)
else:
print("Error: no admin user found.", file=sys.stderr)
return 1
new_password = secrets.token_urlsafe(16)
user.password_hash = hash_password(new_password)
user.token_version += 1
user.needs_setup = True
await repo.update_user(user)
cred_path = write_initial_credentials(user.email, new_password, label="reset")
print(f"Password reset for: {user.email}")
print(f"Credentials written to: {cred_path} (mode 0600)")
print("Next login will require setup (new email + password).")
return 0
finally:
await close_engine()
def main() -> None:
parser = argparse.ArgumentParser(description="Reset admin password")
parser.add_argument("--email", help="Admin email (default: first admin found)")
args = parser.parse_args()
exit_code = asyncio.run(_run(args.email))
sys.exit(exit_code)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,126 @@
"""Global authentication middleware — fail-closed safety net.
Rejects unauthenticated requests to non-public paths with 401. When a
request passes the cookie check, resolves the JWT payload to a real
``User`` object and stamps it into both ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filtering works automatically via the sentinel pattern.
Fine-grained permission checks remain in authz.py decorators.
"""
from collections.abc import Callable
from fastapi import HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
from deerflow.runtime.user_context import reset_current_user, set_current_user
# Paths that never require authentication.
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
"/health",
"/docs",
"/redoc",
"/openapi.json",
)
# Exact auth paths that are public (login/register/status check).
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
"/api/v1/auth/initialize",
}
)
def _is_public(path: str) -> bool:
stripped = path.rstrip("/")
if stripped in _PUBLIC_EXACT_PATHS:
return True
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
class AuthMiddleware(BaseHTTPMiddleware):
"""Strict auth gate: reject requests without a valid session.
Two-stage check for non-public paths:
1. Cookie presence return 401 NOT_AUTHENTICATED if missing
2. JWT validation via ``get_optional_user_from_request`` return 401
TOKEN_INVALID if the token is absent, malformed, expired, or the
signed user does not exist / is stale
On success, stamps ``request.state.user`` and the
``deerflow.runtime.user_context`` contextvar so that repository-layer
owner filters work downstream without every route needing a
``@require_auth`` decorator. Routes that need per-resource
authorization (e.g. "user A cannot read user B's thread by guessing
the URL") should additionally use ``@require_permission(...,
owner_check=True)`` for explicit enforcement but authentication
itself is fully handled here.
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
if _is_public(request.url.path):
return await call_next(request)
internal_user = None
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
internal_user = get_internal_user()
# Non-public path: require session cookie
if internal_user is None and not request.cookies.get("access_token"):
return JSONResponse(
status_code=401,
content={
"detail": AuthErrorResponse(
code=AuthErrorCode.NOT_AUTHENTICATED,
message="Authentication required",
).model_dump()
},
)
# Strict JWT validation: reject junk/expired tokens with 401
# right here instead of silently passing through. This closes
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
# without this, non-isolation routes like /api/models would
# accept any cookie-shaped string as authentication.
#
# We call the *strict* resolver so that fine-grained error
# codes (token_expired, token_invalid, user_not_found, …)
# propagate from AuthErrorCode, not get flattened into one
# generic code. BaseHTTPMiddleware doesn't let HTTPException
# bubble up, so we catch and render it as JSONResponse here.
from app.gateway.deps import get_current_user_from_request
if internal_user is not None:
user = internal_user
else:
try:
user = await get_current_user_from_request(request)
except HTTPException as exc:
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
# Stamp both request.state.user (for the contextvar pattern)
# and request.state.auth (so @require_permission's "auth is
# None" branch short-circuits instead of running the entire
# JWT-decode + DB-lookup pipeline a second time per request).
request.state.user = user
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
token = set_current_user(user)
try:
return await call_next(request)
finally:
reset_current_user(token)

View File

@ -0,0 +1,301 @@
"""Authorization decorators and context for DeerFlow.
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
**Usage:**
1. Use ``@require_auth`` on routes that need authentication
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
3. The decorator chain processes from bottom to top
**Example:**
@router.get("/{thread_id}")
@require_auth
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
# User is authenticated and has threads:read permission
...
**Permission Model:**
- threads:read - View thread
- threads:write - Create/update thread
- threads:delete - Delete thread
- runs:create - Run agent
- runs:read - View run
- runs:cancel - Cancel run
"""
from __future__ import annotations
import functools
import inspect
from collections.abc import Callable
from types import SimpleNamespace
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
from fastapi import HTTPException, Request
if TYPE_CHECKING:
from app.gateway.auth.models import User
P = ParamSpec("P")
T = TypeVar("T")
# Permission constants
class Permissions:
"""Permission constants for resource:action format."""
# Threads
THREADS_READ = "threads:read"
THREADS_WRITE = "threads:write"
THREADS_DELETE = "threads:delete"
# Runs
RUNS_CREATE = "runs:create"
RUNS_READ = "runs:read"
RUNS_CANCEL = "runs:cancel"
class AuthContext:
"""Authentication context for the current request.
Stored in request.state.auth after require_auth decoration.
Attributes:
user: The authenticated user, or None if anonymous
permissions: List of permission strings (e.g., "threads:read")
"""
__slots__ = ("user", "permissions")
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
self.user = user
self.permissions = permissions or []
@property
def is_authenticated(self) -> bool:
"""Check if user is authenticated."""
return self.user is not None
def has_permission(self, resource: str, action: str) -> bool:
"""Check if context has permission for resource:action.
Args:
resource: Resource name (e.g., "threads")
action: Action name (e.g., "read")
Returns:
True if user has permission
"""
permission = f"{resource}:{action}"
return permission in self.permissions
def require_user(self) -> User:
"""Get user or raise 401.
Raises:
HTTPException 401 if not authenticated
"""
if not self.user:
raise HTTPException(status_code=401, detail="Authentication required")
return self.user
def get_auth_context(request: Request) -> AuthContext | None:
"""Get AuthContext from request state."""
return getattr(request.state, "auth", None)
_ALL_PERMISSIONS: list[str] = [
Permissions.THREADS_READ,
Permissions.THREADS_WRITE,
Permissions.THREADS_DELETE,
Permissions.RUNS_CREATE,
Permissions.RUNS_READ,
Permissions.RUNS_CANCEL,
]
def _make_test_request_stub() -> Any:
"""Create a minimal request-like object for direct unit calls.
Used when decorated route handlers are invoked without FastAPI's
request injection. Includes fields accessed by auth helpers.
"""
return SimpleNamespace(state=SimpleNamespace(), cookies={}, _deerflow_test_bypass_auth=True)
async def _authenticate(request: Request) -> AuthContext:
"""Authenticate request and return AuthContext.
Delegates to deps.get_optional_user_from_request() for the JWTUser pipeline.
Returns AuthContext with user=None for anonymous requests.
"""
from app.gateway.deps import get_optional_user_from_request
user = await get_optional_user_from_request(request)
if user is None:
return AuthContext(user=None, permissions=[])
# In future, permissions could be stored in user record
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
"""Decorator that authenticates the request and enforces authentication.
Independently raises HTTP 401 for unauthenticated requests, regardless of
whether ``AuthMiddleware`` is present in the ASGI stack. Sets the resolved
``AuthContext`` on ``request.state.auth`` for downstream handlers.
Must be placed ABOVE other decorators (executes after them).
Usage:
@router.get("/{thread_id}")
@require_auth # Bottom decorator (executes first after permission check)
@require_permission("threads", "read")
async def get_thread(thread_id: str, request: Request):
auth: AuthContext = request.state.auth
...
Raises:
HTTPException: 401 if the request is unauthenticated.
ValueError: If 'request' parameter is missing.
"""
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
# Unit tests may call decorated handlers directly without a
# FastAPI Request object. Inject a minimal request stub when
# the wrapped function declares `request`.
if "request" in inspect.signature(func).parameters:
kwargs["request"] = _make_test_request_stub()
else:
raise ValueError("require_auth decorator requires 'request' parameter")
request = kwargs["request"]
if getattr(request, "_deerflow_test_bypass_auth", False):
return await func(*args, **kwargs)
# Authenticate and set context
auth_context = await _authenticate(request)
request.state.auth = auth_context
if not auth_context.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
return await func(*args, **kwargs)
return wrapper
def require_permission(
resource: str,
action: str,
owner_check: bool = False,
require_existing: bool = False,
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Decorator that checks permission for resource:action.
Must be used AFTER @require_auth.
Args:
resource: Resource name (e.g., "threads", "runs")
action: Action name (e.g., "read", "write", "delete")
owner_check: If True, validates that the current user owns the resource.
Requires 'thread_id' path parameter and performs ownership check.
require_existing: Only meaningful with ``owner_check=True``. If True, a
missing ``threads_meta`` row counts as a denial (404)
instead of "untracked legacy thread, allow". Use on
**destructive / mutating** routes (DELETE, PATCH,
state-update) so a deleted thread can't be re-targeted
by another user via the missing-row code path.
Usage:
# Read-style: legacy untracked threads are allowed
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request):
...
# Destructive: thread row MUST exist and be owned by caller
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread(thread_id: str, request: Request):
...
Raises:
HTTPException 401: If authentication required but user is anonymous
HTTPException 403: If user lacks permission
HTTPException 404: If owner_check=True but user doesn't own the thread
ValueError: If owner_check=True but 'thread_id' parameter is missing
"""
def decorator(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
request = kwargs.get("request")
if request is None:
# Unit tests may call decorated route handlers directly without
# constructing a FastAPI Request object. Inject a minimal stub
# when the wrapped function declares `request`.
if "request" in inspect.signature(func).parameters:
kwargs["request"] = _make_test_request_stub()
else:
return await func(*args, **kwargs)
request = kwargs["request"]
if getattr(request, "_deerflow_test_bypass_auth", False):
return await func(*args, **kwargs)
auth: AuthContext = getattr(request.state, "auth", None)
if auth is None:
auth = await _authenticate(request)
request.state.auth = auth
if not auth.is_authenticated:
raise HTTPException(status_code=401, detail="Authentication required")
# Check permission
if not auth.has_permission(resource, action):
raise HTTPException(
status_code=403,
detail=f"Permission denied: {resource}:{action}",
)
# Owner check for thread-specific resources.
#
# 2.0-rc moved thread metadata into the SQL persistence layer
# (``threads_meta`` table). We verify ownership via
# ``ThreadMetaStore.check_access``: it returns True for
# missing rows (untracked legacy thread) and for rows whose
# ``user_id`` is NULL (shared / pre-auth data), so this is
# strict-deny rather than strict-allow — only an *existing*
# row with a *different* user_id triggers 404.
if owner_check:
thread_id = kwargs.get("thread_id")
if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
allowed = await thread_store.check_access(
thread_id,
str(auth.user.id),
require_existing=require_existing,
)
if not allowed:
raise HTTPException(
status_code=404,
detail=f"Thread {thread_id} not found",
)
return await func(*args, **kwargs)
return wrapper
return decorator

View File

@ -9,6 +9,7 @@ class GatewayConfig(BaseModel):
host: str = Field(default="0.0.0.0", description="Host to bind the gateway server")
port: int = Field(default=8001, description="Port to bind the gateway server")
cors_origins: list[str] = Field(default_factory=lambda: ["http://localhost:3000"], description="Allowed CORS origins")
enable_docs: bool = Field(default=True, description="Enable Swagger/ReDoc/OpenAPI endpoints")
_gateway_config: GatewayConfig | None = None
@ -23,5 +24,6 @@ def get_gateway_config() -> GatewayConfig:
host=os.getenv("GATEWAY_HOST", "0.0.0.0"),
port=int(os.getenv("GATEWAY_PORT", "8001")),
cors_origins=cors_origins_str.split(","),
enable_docs=os.getenv("GATEWAY_ENABLE_DOCS", "true").lower() == "true",
)
return _gateway_config

View File

@ -0,0 +1,113 @@
"""CSRF protection middleware for FastAPI.
Per RFC-001:
State-changing operations require CSRF protection.
"""
import secrets
from collections.abc import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
from starlette.types import ASGIApp
CSRF_COOKIE_NAME = "csrf_token"
CSRF_HEADER_NAME = "X-CSRF-Token"
CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
def generate_csrf_token() -> str:
"""Generate a secure random CSRF token."""
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
def should_check_csrf(request: Request) -> bool:
"""Determine if a request needs CSRF validation.
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
"""
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
return False
path = request.url.path.rstrip("/")
# Exempt /api/v1/auth/me endpoint
if path == "/api/v1/auth/me":
return False
return True
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/login/local",
"/api/v1/auth/logout",
"/api/v1/auth/register",
"/api/v1/auth/initialize",
}
)
def is_auth_endpoint(request: Request) -> bool:
"""Check if the request is to an auth endpoint.
Auth endpoints don't need CSRF validation on first call (no token).
"""
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME)
if not cookie_token or not header_token:
return JSONResponse(
status_code=403,
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
)
if not secrets.compare_digest(cookie_token, header_token):
return JSONResponse(
status_code=403,
content={"detail": "CSRF token mismatch."},
)
response = await call_next(request)
# For auth endpoints that set up session, also set CSRF cookie
if _is_auth and request.method == "POST":
# Generate a new CSRF token for the session
csrf_token = generate_csrf_token()
is_https = is_secure_request(request)
response.set_cookie(
key=CSRF_COOKIE_NAME,
value=csrf_token,
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
secure=is_https,
samesite="strict",
)
return response
def get_csrf_token(request: Request) -> str | None:
"""Get the CSRF token from the current request's cookies.
This is useful for server-side rendering where you need to embed
token in forms or headers.
"""
return request.cookies.get(CSRF_COOKIE_NAME)

View File

@ -8,12 +8,34 @@ Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
from __future__ import annotations
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING, TypeVar, cast
from fastapi import FastAPI, HTTPException, Request
from langgraph.types import Checkpointer
from deerflow.runtime import RunManager, StreamBridge
from deerflow.config.app_config import AppConfig
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.runtime import RunContext, RunManager, StreamBridge
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.runs.store.base import RunStore
if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.thread_meta.base import ThreadMetaStore
T = TypeVar("T")
def get_config(request: Request) -> AppConfig:
"""Return the app-scoped ``AppConfig`` stored on ``app.state``."""
config = getattr(request.app.state, "config", None)
if config is None:
raise HTTPException(status_code=503, detail="Configuration not available")
return config
@asynccontextmanager
@ -25,15 +47,54 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app):
yield
"""
from deerflow.agents.checkpointer.async_provider import make_checkpointer
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack:
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
app.state.store = await stack.enter_async_context(make_store())
app.state.run_manager = RunManager()
yield
config = getattr(app.state, "config", None)
if config is None:
raise RuntimeError("langgraph_runtime() requires app.state.config to be initialized")
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge(config))
# Initialize persistence engine BEFORE checkpointer so that
# auto-create-database logic runs first (postgres backend).
await init_engine_from_config(config.database)
app.state.checkpointer = await stack.enter_async_context(make_checkpointer(config))
app.state.store = await stack.enter_async_context(make_store(config))
# Initialize repositories — one get_session_factory() call for all.
sf = get_session_factory()
if sf is not None:
from deerflow.persistence.feedback import FeedbackRepository
from deerflow.persistence.run import RunRepository
app.state.run_store = RunRepository(sf)
app.state.feedback_repo = FeedbackRepository(sf)
else:
from deerflow.runtime.runs.store.memory import MemoryRunStore
app.state.run_store = MemoryRunStore()
app.state.feedback_repo = None
from deerflow.persistence.thread_meta import make_thread_store
app.state.thread_store = make_thread_store(sf, app.state.store)
# Run event store (has its own factory with config-driven backend selection)
run_events_config = getattr(config, "run_events", None)
app.state.run_event_store = make_run_event_store(run_events_config)
# RunManager with store backing for persistence
app.state.run_manager = RunManager(store=app.state.run_store)
try:
yield
finally:
await close_engine()
# ---------------------------------------------------------------------------
@ -41,30 +102,144 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
# ---------------------------------------------------------------------------
def get_stream_bridge(request: Request) -> StreamBridge:
"""Return the global :class:`StreamBridge`, or 503."""
bridge = getattr(request.app.state, "stream_bridge", None)
if bridge is None:
raise HTTPException(status_code=503, detail="Stream bridge not available")
return bridge
def _require(attr: str, label: str) -> Callable[[Request], T]:
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
def dep(request: Request) -> T:
val = getattr(request.app.state, attr, None)
if val is None:
raise HTTPException(status_code=503, detail=f"{label} not available")
return cast(T, val)
dep.__name__ = dep.__qualname__ = f"get_{attr}"
return dep
def get_run_manager(request: Request) -> RunManager:
"""Return the global :class:`RunManager`, or 503."""
mgr = getattr(request.app.state, "run_manager", None)
if mgr is None:
raise HTTPException(status_code=503, detail="Run manager not available")
return mgr
def get_checkpointer(request: Request):
"""Return the global checkpointer, or 503."""
cp = getattr(request.app.state, "checkpointer", None)
if cp is None:
raise HTTPException(status_code=503, detail="Checkpointer not available")
return cp
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
def get_store(request: Request):
"""Return the global store (may be ``None`` if not configured)."""
return getattr(request.app.state, "store", None)
def get_thread_store(request: Request) -> ThreadMetaStore:
"""Return the thread metadata store (SQL or memory-backed)."""
val = getattr(request.app.state, "thread_store", None)
if val is None:
raise HTTPException(status_code=503, detail="Thread metadata store not available")
return val
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies.
"""
config = get_config(request)
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(config, "run_events", None),
thread_store=get_thread_store(request),
app_config=config,
)
# ---------------------------------------------------------------------------
# Auth helpers (used by authz.py and auth middleware)
# ---------------------------------------------------------------------------
# Cached singletons to avoid repeated instantiation per request
_cached_local_provider: LocalAuthProvider | None = None
_cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton.
Must be called after ``init_engine_from_config()`` the shared
session factory is required to construct the user repository.
"""
global _cached_local_provider, _cached_repo
if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.engine import get_session_factory
sf = get_session_factory()
if sf is None:
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
_cached_repo = SQLiteUserRepository(sf)
if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
return _cached_local_provider
async def get_current_user_from_request(request: Request):
"""Get the current authenticated user from the request cookie.
Raises HTTPException 401 if not authenticated.
"""
from app.gateway.auth import decode_token
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
access_token = request.cookies.get("access_token")
if not access_token:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
)
payload = decode_token(access_token)
if isinstance(payload, TokenError):
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
)
provider = get_local_provider()
user = await provider.get_user(payload.sub)
if user is None:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
)
# Token version mismatch → password was changed, token is stale
if user.token_version != payload.ver:
raise HTTPException(
status_code=401,
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
)
return user
async def get_optional_user_from_request(request: Request):
"""Get optional authenticated user from request.
Returns None if not authenticated.
"""
try:
return await get_current_user_from_request(request)
except HTTPException:
return None
async def get_current_user(request: Request) -> str | None:
"""Extract user_id from request cookie, or None if not authenticated.
Thin adapter that returns the string id for callers that only need
identification (e.g., ``feedback.py``). Full-user callers should use
``get_current_user_from_request`` or ``get_optional_user_from_request``.
"""
user = await get_optional_user_from_request(request)
return str(user.id) if user else None

View File

@ -0,0 +1,26 @@
"""Process-local authentication for Gateway internal callers."""
from __future__ import annotations
import secrets
from types import SimpleNamespace
from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
_INTERNAL_AUTH_TOKEN = secrets.token_urlsafe(32)
def create_internal_auth_headers() -> dict[str, str]:
"""Return headers that authenticate same-process Gateway internal calls."""
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
def is_valid_internal_auth_token(token: str | None) -> bool:
"""Return True when *token* matches the process-local internal token."""
return bool(token) and secrets.compare_digest(token, _INTERNAL_AUTH_TOKEN)
def get_internal_user():
"""Return the synthetic user used for trusted internal channel calls."""
return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")

View File

@ -0,0 +1,106 @@
"""LangGraph Server auth handler — shares JWT logic with Gateway.
Loaded by LangGraph Server via langgraph.json ``auth.path``.
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
so both modes validate tokens with the same secret and rules.
Two layers:
1. @auth.authenticate validates JWT cookie, extracts user_id,
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
2. @auth.on returns metadata filter so each user only sees own threads
"""
import secrets
from langgraph_sdk import Auth
from app.gateway.auth.errors import TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.deps import get_local_provider
auth = Auth()
# Methods that require CSRF validation (state-changing per RFC 7231).
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
def _check_csrf(request) -> None:
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
proxied directly by nginx have the same CSRF protection.
"""
method = getattr(request, "method", "") or ""
if method.upper() not in _CSRF_METHODS:
return
cookie_token = request.cookies.get("csrf_token")
header_token = request.headers.get("x-csrf-token")
if not cookie_token or not header_token:
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token missing. Include X-CSRF-Token header.",
)
if not secrets.compare_digest(cookie_token, header_token):
raise Auth.exceptions.HTTPException(
status_code=403,
detail="CSRF token mismatch.",
)
@auth.authenticate
async def authenticate(request):
"""Validate the session cookie, decode JWT, and check token_version.
Same validation chain as Gateway's get_current_user_from_request:
cookie decode JWT DB lookup token_version match
Also enforces CSRF on state-changing methods.
"""
# CSRF check before authentication so forged cross-site requests
# are rejected early, even if the cookie carries a valid JWT.
_check_csrf(request)
token = request.cookies.get("access_token")
if not token:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Not authenticated",
)
payload = decode_token(token)
if isinstance(payload, TokenError):
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Invalid token",
)
user = await get_local_provider().get_user(payload.sub)
if user is None:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="User not found",
)
if user.token_version != payload.ver:
raise Auth.exceptions.HTTPException(
status_code=401,
detail="Token revoked (password changed)",
)
return payload.sub
@auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
"""Inject user_id metadata on writes; filter by user_id on reads.
Gateway stores thread ownership as ``metadata.user_id``.
This handler ensures LangGraph Server enforces the same isolation.
"""
# On create/update: stamp user_id into metadata
metadata = value.setdefault("metadata", {})
metadata["user_id"] = ctx.user.identity
# Return filter dict — LangGraph applies it to search/read/delete
return {"user_id": ctx.user.identity}

View File

@ -5,6 +5,7 @@ from pathlib import Path
from fastapi import HTTPException
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
@ -22,7 +23,7 @@ def resolve_thread_virtual_path(thread_id: str, virtual_path: str) -> Path:
HTTPException: If the path is invalid or outside allowed directories.
"""
try:
return get_paths().resolve_virtual_path(thread_id, virtual_path)
return get_paths().resolve_virtual_path(thread_id, virtual_path, user_id=get_effective_user_id())
except ValueError as e:
status = 403 if "traversal" in str(e) else 400
raise HTTPException(status_code=status, detail=str(e))

View File

@ -25,6 +25,7 @@ class AgentResponse(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="SOUL.md content")
@ -41,6 +42,7 @@ class AgentCreateRequest(BaseModel):
description: str = Field(default="", description="Agent description")
model: str | None = Field(default=None, description="Optional model override")
tool_groups: list[str] | None = Field(default=None, description="Optional tool group whitelist")
skills: list[str] | None = Field(default=None, description="Optional skill whitelist (None=all enabled, []=none)")
soul: str = Field(default="", description="SOUL.md content — agent personality and behavioral guardrails")
@ -50,6 +52,7 @@ class AgentUpdateRequest(BaseModel):
description: str | None = Field(default=None, description="Updated description")
model: str | None = Field(default=None, description="Updated model override")
tool_groups: list[str] | None = Field(default=None, description="Updated tool group whitelist")
skills: list[str] | None = Field(default=None, description="Updated skill whitelist (None=all, []=none)")
soul: str | None = Field(default=None, description="Updated SOUL.md content")
@ -94,6 +97,7 @@ def _agent_config_to_response(agent_cfg: AgentConfig, include_soul: bool = False
description=agent_cfg.description,
model=agent_cfg.model,
tool_groups=agent_cfg.tool_groups,
skills=agent_cfg.skills,
soul=soul,
)
@ -215,6 +219,8 @@ async def create_agent_endpoint(request: AgentCreateRequest) -> AgentResponse:
config_data["model"] = request.model
if request.tool_groups is not None:
config_data["tool_groups"] = request.tool_groups
if request.skills is not None:
config_data["skills"] = request.skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
@ -271,21 +277,32 @@ async def update_agent(name: str, request: AgentUpdateRequest) -> AgentResponse:
try:
# Update config if any config fields changed
config_changed = any(v is not None for v in [request.description, request.model, request.tool_groups])
# Use model_fields_set to distinguish "field omitted" from "explicitly set to null".
# This is critical for skills where None means "inherit all" (not "don't change").
fields_set = request.model_fields_set
config_changed = bool(fields_set & {"description", "model", "tool_groups", "skills"})
if config_changed:
updated: dict = {
"name": agent_cfg.name,
"description": request.description if request.description is not None else agent_cfg.description,
"description": request.description if "description" in fields_set else agent_cfg.description,
}
new_model = request.model if request.model is not None else agent_cfg.model
new_model = request.model if "model" in fields_set else agent_cfg.model
if new_model is not None:
updated["model"] = new_model
new_tool_groups = request.tool_groups if request.tool_groups is not None else agent_cfg.tool_groups
new_tool_groups = request.tool_groups if "tool_groups" in fields_set else agent_cfg.tool_groups
if new_tool_groups is not None:
updated["tool_groups"] = new_tool_groups
# skills: None = inherit all, [] = no skills, ["a","b"] = whitelist
if "skills" in fields_set:
new_skills = request.skills
else:
new_skills = agent_cfg.skills
if new_skills is not None:
updated["skills"] = new_skills
config_file = agent_dir / "config.yaml"
with open(config_file, "w", encoding="utf-8") as f:
yaml.dump(updated, f, default_flow_style=False, allow_unicode=True)

View File

@ -7,6 +7,7 @@ from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.authz import require_permission
from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__)
@ -81,6 +82,7 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
)
@require_permission("threads", "read", owner_check=True)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path.

View File

@ -0,0 +1,493 @@
"""Authentication endpoints."""
import logging
import os
import time
from ipaddress import ip_address, ip_network
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr, Field, field_validator
from app.gateway.auth import (
UserResponse,
create_access_token,
)
from app.gateway.auth.config import get_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
from app.gateway.csrf_middleware import is_secure_request
from app.gateway.deps import get_current_user_from_request, get_local_provider
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
# ── Request/Response Models ──────────────────────────────────────────────
class LoginResponse(BaseModel):
"""Response model for login — token only lives in HttpOnly cookie."""
expires_in: int # seconds
needs_setup: bool = False
# Top common-password blocklist. Drawn from the public SecLists "10k worst
# passwords" set, lowercased + length>=8 only (shorter ones already fail
# the min_length check). Kept tight on purpose: this is the **lower bound**
# defense, not a full HIBP / passlib check, and runs in-process per request.
_COMMON_PASSWORDS: frozenset[str] = frozenset(
{
"password",
"password1",
"password12",
"password123",
"password1234",
"12345678",
"123456789",
"1234567890",
"qwerty12",
"qwertyui",
"qwerty123",
"abc12345",
"abcd1234",
"iloveyou",
"letmein1",
"welcome1",
"welcome123",
"admin123",
"administrator",
"passw0rd",
"p@ssw0rd",
"monkey12",
"trustno1",
"sunshine",
"princess",
"football",
"baseball",
"superman",
"batman123",
"starwars",
"dragon123",
"master123",
"shadow12",
"michael1",
"jennifer",
"computer",
}
)
def _password_is_common(password: str) -> bool:
"""Case-insensitive blocklist check.
Lowercases the input so trivial mutations like ``Password`` /
``PASSWORD`` are also rejected. Does not normalize digit substitutions
(``p@ssw0rd`` is included as a literal entry instead) keeping the
rule cheap and predictable.
"""
return password.lower() in _COMMON_PASSWORDS
def _validate_strong_password(value: str) -> str:
"""Pydantic field-validator body shared by Register + ChangePassword.
Constraint = function, not type-level mixin. The two request models
have no "is-a" relationship; they only share the password-strength
rule. Lifting it into a free function lets each model bind it via
``@field_validator(field_name)`` without inheritance gymnastics.
"""
if _password_is_common(value):
raise ValueError("Password is too common; choose a stronger password.")
return value
class RegisterRequest(BaseModel):
"""Request model for user registration."""
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class ChangePasswordRequest(BaseModel):
"""Request model for password change (also handles setup flow)."""
current_password: str
new_password: str = Field(..., min_length=8)
new_email: EmailStr | None = None
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
class MessageResponse(BaseModel):
"""Generic message response."""
message: str
# ── Helpers ───────────────────────────────────────────────────────────────
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
"""Set the access_token HttpOnly cookie on the response."""
config = get_auth_config()
is_https = is_secure_request(request)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=is_https,
samesite="lax",
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
)
# ── Rate Limiting ────────────────────────────────────────────────────────
# In-process dict — not shared across workers.
#
# **Limitation**: with multi-worker deployments (e.g., gunicorn -w N), each
# worker maintains its own lockout table, so an attacker effectively gets
# N × _MAX_LOGIN_ATTEMPTS guesses before being locked out everywhere. For
# production multi-worker setups, replace this with a shared store (Redis,
# database-backed counter) to enforce a true per-IP limit.
_MAX_LOGIN_ATTEMPTS = 5
_LOCKOUT_SECONDS = 300 # 5 minutes
# ip → (fail_count, lock_until_timestamp)
_login_attempts: dict[str, tuple[int, float]] = {}
def _trusted_proxies() -> list:
"""Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects.
Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is
trusted (direct mode). Invalid entries are skipped with a logger warning.
Read live so env-var overrides take effect immediately and tests can
``monkeypatch.setenv`` without poking a module-level cache.
"""
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
if not raw:
return []
nets = []
for entry in raw.split(","):
entry = entry.strip()
if not entry:
continue
try:
nets.append(ip_network(entry, strict=False))
except ValueError:
logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry)
return nets
def _get_client_ip(request: Request) -> str:
"""Extract the real client IP for rate limiting.
Trust model:
- The TCP peer (``request.client.host``) is always the baseline. It is
whatever the kernel reports as the connecting socket unforgeable
by the client itself.
- ``X-Real-IP`` is **only** honored if the TCP peer is in the
``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated
CIDR or single IPs). When set, the gateway is assumed to be behind a
reverse proxy (nginx, Cloudflare, ALB, ) that overwrites
``X-Real-IP`` with the original client address.
- With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently
ignored closing the bypass where any client could rotate the
header to dodge per-IP rate limits in dev / direct-gateway mode.
``X-Forwarded-For`` is intentionally NOT used because it is naturally
client-controlled at the *first* hop and the trust chain is harder to
audit per-request.
"""
peer_host = request.client.host if request.client else None
trusted = _trusted_proxies()
if trusted and peer_host:
try:
peer_ip = ip_address(peer_host)
if any(peer_ip in net for net in trusted):
real_ip = request.headers.get("x-real-ip", "").strip()
if real_ip:
return real_ip
except ValueError:
# peer_host wasn't a parseable IP (e.g. "unknown") — fall through
pass
return peer_host or "unknown"
def _check_rate_limit(ip: str) -> None:
"""Raise 429 if the IP is currently locked out."""
record = _login_attempts.get(ip)
if record is None:
return
fail_count, lock_until = record
if fail_count >= _MAX_LOGIN_ATTEMPTS:
if time.time() < lock_until:
raise HTTPException(
status_code=429,
detail="Too many login attempts. Try again later.",
)
del _login_attempts[ip]
_MAX_TRACKED_IPS = 10000
def _record_login_failure(ip: str) -> None:
"""Record a failed login attempt for the given IP."""
# Evict expired lockouts when dict grows too large
if len(_login_attempts) >= _MAX_TRACKED_IPS:
now = time.time()
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
for k in expired:
del _login_attempts[k]
# If still too large, evict cheapest-to-lose half: below-threshold
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
if len(_login_attempts) >= _MAX_TRACKED_IPS:
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
for k, _ in by_time[: len(by_time) // 2]:
del _login_attempts[k]
record = _login_attempts.get(ip)
if record is None:
_login_attempts[ip] = (1, 0.0)
else:
new_count = record[0] + 1
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
_login_attempts[ip] = (new_count, lock_until)
def _record_login_success(ip: str) -> None:
"""Clear failure counter for the given IP on successful login."""
_login_attempts.pop(ip, None)
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("/login/local", response_model=LoginResponse)
async def login_local(
request: Request,
response: Response,
form_data: OAuth2PasswordRequestForm = Depends(),
):
"""Local email/password login."""
client_ip = _get_client_ip(request)
_check_rate_limit(client_ip)
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
if user is None:
_record_login_failure(client_ip)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
)
_record_login_success(client_ip)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return LoginResponse(
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
needs_setup=user.needs_setup,
)
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(request: Request, response: Response, body: RegisterRequest):
"""Register a new user account (always 'user' role).
Admin is auto-created on first boot. This endpoint creates regular users.
Auto-login by setting the session cookie.
"""
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
@router.post("/logout", response_model=MessageResponse)
async def logout(request: Request, response: Response):
"""Logout current user by clearing the cookie."""
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
return MessageResponse(message="Successfully logged out")
@router.post("/change-password", response_model=MessageResponse)
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
"""Change password for the currently authenticated user.
Also handles the first-boot setup flow:
- If new_email is provided, updates email (checks uniqueness)
- If user.needs_setup is True and new_email is given, clears needs_setup
- Always increments token_version to invalidate old sessions
- Re-issues session cookie with new token_version
"""
from app.gateway.auth.password import hash_password_async, verify_password_async
user = await get_current_user_from_request(request)
if user.password_hash is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
if not await verify_password_async(body.current_password, user.password_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
provider = get_local_provider()
# Update email if provided
if body.new_email is not None:
existing = await provider.get_user_by_email(body.new_email)
if existing and str(existing.id) != str(user.id):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
user.email = body.new_email
# Update password + bump version
user.password_hash = await hash_password_async(body.new_password)
user.token_version += 1
# Clear setup flag if this is the setup flow
if user.needs_setup and body.new_email is not None:
user.needs_setup = False
await provider.update_user(user)
# Re-issue cookie with new token_version
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return MessageResponse(message="Password changed successfully")
@router.get("/me", response_model=UserResponse)
async def get_me(request: Request):
"""Get current authenticated user info."""
user = await get_current_user_from_request(request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
_SETUP_STATUS_COOLDOWN: dict[str, float] = {}
_SETUP_STATUS_COOLDOWN_SECONDS = 60
_MAX_TRACKED_SETUP_STATUS_IPS = 10000
@router.get("/setup-status")
async def setup_status(request: Request):
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
client_ip = _get_client_ip(request)
now = time.time()
last_check = _SETUP_STATUS_COOLDOWN.get(client_ip, 0)
elapsed = now - last_check
if elapsed < _SETUP_STATUS_COOLDOWN_SECONDS:
retry_after = max(1, int(_SETUP_STATUS_COOLDOWN_SECONDS - elapsed))
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Setup status check is rate limited",
headers={"Retry-After": str(retry_after)},
)
# Evict stale entries when dict grows too large to bound memory usage.
if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
cutoff = now - _SETUP_STATUS_COOLDOWN_SECONDS
stale = [k for k, t in _SETUP_STATUS_COOLDOWN.items() if t < cutoff]
for k in stale:
del _SETUP_STATUS_COOLDOWN[k]
# If still too large after evicting expired entries, remove oldest half.
if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
by_time = sorted(_SETUP_STATUS_COOLDOWN.items(), key=lambda kv: kv[1])
for k, _ in by_time[: len(by_time) // 2]:
del _SETUP_STATUS_COOLDOWN[k]
_SETUP_STATUS_COOLDOWN[client_ip] = now
admin_count = await get_local_provider().count_admin_users()
return {"needs_setup": admin_count == 0}
class InitializeAdminRequest(BaseModel):
"""Request model for first-boot admin account creation."""
email: EmailStr
password: str = Field(..., min_length=8)
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
"""Create the first admin account on initial system setup.
Only callable when no admin exists. Returns 409 Conflict if an admin
already exists.
On success, the admin account is created with ``needs_setup=False`` and
the session cookie is set.
"""
admin_count = await get_local_provider().count_admin_users()
if admin_count > 0:
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
)
try:
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
except ValueError:
# DB unique-constraint race: another concurrent request beat us.
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
)
token = create_access_token(str(user.id), token_version=user.token_version)
_set_session_cookie(response, token, request)
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
@router.get("/oauth/{provider}")
async def oauth_login(provider: str):
"""Initiate OAuth login flow.
Redirects to the OAuth provider's authorization URL.
Currently a placeholder - requires OAuth provider implementation.
"""
if provider not in ["github", "google"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported OAuth provider: {provider}",
)
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth login not yet implemented",
)
@router.get("/callback/{provider}")
async def oauth_callback(provider: str, code: str, state: str):
"""OAuth callback endpoint.
Handles the OAuth provider's callback after user authorization.
Currently a placeholder.
"""
raise HTTPException(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="OAuth callback not yet implemented",
)

View File

@ -0,0 +1,188 @@
"""Feedback endpoints — create, list, stats, delete.
Allows users to submit thumbs-up/down feedback on runs,
optionally scoped to a specific message.
"""
from __future__ import annotations
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["feedback"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class FeedbackCreateRequest(BaseModel):
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
comment: str | None = Field(default=None, description="Optional text feedback")
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
class FeedbackUpsertRequest(BaseModel):
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
comment: str | None = Field(default=None, description="Optional text feedback")
class FeedbackResponse(BaseModel):
feedback_id: str
run_id: str
thread_id: str
user_id: str | None = None
message_id: str | None = None
rating: int
comment: str | None = None
created_at: str = ""
class FeedbackStatsResponse(BaseModel):
run_id: str
total: int = 0
positive: int = 0
negative: int = 0
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upsert_feedback(
thread_id: str,
run_id: str,
body: FeedbackUpsertRequest,
request: Request,
) -> dict[str, Any]:
"""Create or update feedback for a run (idempotent)."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
return await feedback_repo.upsert(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
comment=body.comment,
)
@router.delete("/{thread_id}/runs/{run_id}/feedback")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_run_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete the current user's feedback for a run."""
user_id = await get_current_user(request)
feedback_repo = get_feedback_repo(request)
deleted = await feedback_repo.delete_by_run(
thread_id=thread_id,
run_id=run_id,
user_id=user_id,
)
if not deleted:
raise HTTPException(status_code=404, detail="No feedback found for this run")
return {"success": True}
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def create_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Submit feedback (thumbs-up/down) for a run."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
# Validate run exists and belongs to thread
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
return await feedback_repo.create(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
message_id=body.message_id,
comment=body.comment,
)
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
@require_permission("threads", "read", owner_check=True)
async def list_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> list[dict[str, Any]]:
"""List all feedback for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(thread_id, run_id)
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
@require_permission("threads", "read", owner_check=True)
async def feedback_stats(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, Any]:
"""Get aggregated feedback stats (positive/negative counts) for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.aggregate_by_run(thread_id, run_id)
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_feedback(
thread_id: str,
run_id: str,
feedback_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete a feedback record."""
feedback_repo = get_feedback_repo(request)
# Verify feedback belongs to the specified thread/run before deleting
existing = await feedback_repo.get(feedback_id)
if existing is None:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
deleted = await feedback_repo.delete(feedback_id)
if not deleted:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
return {"success": True}

View File

@ -13,6 +13,7 @@ from deerflow.agents.memory.updater import (
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
router = APIRouter(prefix="/api", tags=["memory"])
@ -147,7 +148,7 @@ async def get_memory() -> MemoryResponse:
}
```
"""
memory_data = get_memory_data()
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@ -167,7 +168,7 @@ async def reload_memory() -> MemoryResponse:
Returns:
The reloaded memory data.
"""
memory_data = reload_memory_data()
memory_data = reload_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@ -181,7 +182,7 @@ async def reload_memory() -> MemoryResponse:
async def clear_memory() -> MemoryResponse:
"""Clear all persisted memory data."""
try:
memory_data = clear_memory_data()
memory_data = clear_memory_data(user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
@ -202,6 +203,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
content=request.content,
category=request.category,
confidence=request.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
@ -221,7 +223,7 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
"""Delete a single fact from memory by fact id."""
try:
memory_data = delete_memory_fact(fact_id)
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
@ -245,6 +247,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
content=request.content,
category=request.category,
confidence=request.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
@ -265,7 +268,7 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
)
async def export_memory() -> MemoryResponse:
"""Export the current memory data."""
memory_data = get_memory_data()
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@ -279,7 +282,7 @@ async def export_memory() -> MemoryResponse:
async def import_memory(request: MemoryResponse) -> MemoryResponse:
"""Import and persist memory data."""
try:
memory_data = import_memory_data(request.model_dump())
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
@ -337,7 +340,7 @@ async def get_memory_status() -> MemoryStatusResponse:
Combined memory configuration and current data.
"""
config = get_memory_config()
memory_data = get_memory_data()
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryStatusResponse(
config=MemoryConfigResponse(

View File

@ -1,7 +1,8 @@
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from deerflow.config import get_app_config
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
router = APIRouter(prefix="/api", tags=["models"])
@ -36,7 +37,7 @@ class ModelsListResponse(BaseModel):
summary="List All Models",
description="Retrieve a list of all available AI models configured in the system.",
)
async def list_models() -> ModelsListResponse:
async def list_models(config: AppConfig = Depends(get_config)) -> ModelsListResponse:
"""List all available models from configuration.
Returns model information suitable for frontend display,
@ -72,7 +73,6 @@ async def list_models() -> ModelsListResponse:
}
```
"""
config = get_app_config()
models = [
ModelResponse(
name=model.name,
@ -96,7 +96,7 @@ async def list_models() -> ModelsListResponse:
summary="Get Model Details",
description="Retrieve detailed information about a specific AI model by its name.",
)
async def get_model(model_name: str) -> ModelResponse:
async def get_model(model_name: str, config: AppConfig = Depends(get_config)) -> ModelResponse:
"""Get a specific model by name.
Args:
@ -118,7 +118,6 @@ async def get_model(model_name: str) -> ModelResponse:
}
```
"""
config = get_app_config()
model = config.get_model_config(model_name)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

View File

@ -11,10 +11,11 @@ import asyncio
import logging
import uuid
from fastapi import APIRouter, Request
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import serialize_channel_values
@ -85,3 +86,58 @@ async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
# ---------------------------------------------------------------------------
# Run-scoped read endpoints
# ---------------------------------------------------------------------------
async def _resolve_run(run_id: str, request: Request) -> dict:
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
run_store = get_run_store(request)
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
if record is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return record
@router.get("/{run_id}/messages")
@require_permission("runs", "read")
async def run_messages(
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a run (cursor-based).
Pagination:
- after_seq: messages with seq > after_seq (forward)
- before_seq: messages with seq < before_seq (backward)
- neither: latest messages
Response: { data: [...], has_more: bool }
"""
run = await _resolve_run(run_id, request)
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
run["thread_id"],
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{run_id}/feedback")
@require_permission("runs", "read")
async def run_feedback(run_id: str, request: Request) -> list[dict]:
"""Return all feedback for a run."""
run = await _resolve_run(run_id, request)
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(run["thread_id"], run_id)

View File

@ -1,30 +1,20 @@
import errno
import json
import logging
import shutil
from pathlib import Path
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from app.gateway.deps import get_config
from app.gateway.path_utils import resolve_thread_virtual_path
from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async
from deerflow.config.app_config import AppConfig
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.skills import Skill, load_skills
from deerflow.skills.installer import SkillAlreadyExistsError, install_skill_from_archive
from deerflow.skills.manager import (
append_history,
atomic_write,
custom_skill_exists,
ensure_custom_skill_is_editable,
get_custom_skill_dir,
get_custom_skill_file,
get_skill_history_file,
read_custom_skill_content,
read_history,
validate_skill_markdown_content,
)
from deerflow.skills import Skill
from deerflow.skills.installer import SkillAlreadyExistsError
from deerflow.skills.security_scanner import scan_skill_content
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import SKILL_MD_FILE, SkillCategory
logger = logging.getLogger(__name__)
@ -37,7 +27,7 @@ class SkillResponse(BaseModel):
name: str = Field(..., description="Name of the skill")
description: str = Field(..., description="Description of what the skill does")
license: str | None = Field(None, description="License information")
category: str = Field(..., description="Category of the skill (public or custom)")
category: SkillCategory = Field(..., description="Category of the skill (public or custom)")
enabled: bool = Field(default=True, description="Whether this skill is enabled")
@ -101,9 +91,9 @@ def _skill_to_response(skill: Skill) -> SkillResponse:
summary="List All Skills",
description="Retrieve a list of all available skills from both public and custom directories.",
)
async def list_skills() -> SkillsListResponse:
async def list_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = load_skills(enabled_only=False)
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error(f"Failed to load skills: {e}", exc_info=True)
@ -116,10 +106,10 @@ async def list_skills() -> SkillsListResponse:
summary="Install Skill",
description="Install a skill from a .skill file (ZIP archive) located in the thread's user-data directory.",
)
async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
async def install_skill(request: SkillInstallRequest, config: AppConfig = Depends(get_config)) -> SkillInstallResponse:
try:
skill_file_path = resolve_thread_virtual_path(request.thread_id, request.path)
result = install_skill_from_archive(skill_file_path)
result = await get_or_new_skill_storage(app_config=config).ainstall_skill_from_archive(skill_file_path)
await refresh_skills_system_prompt_cache_async()
return SkillInstallResponse(**result)
except FileNotFoundError as e:
@ -136,9 +126,9 @@ async def install_skill(request: SkillInstallRequest) -> SkillInstallResponse:
@router.get("/skills/custom", response_model=SkillsListResponse, summary="List Custom Skills")
async def list_custom_skills() -> SkillsListResponse:
async def list_custom_skills(config: AppConfig = Depends(get_config)) -> SkillsListResponse:
try:
skills = [skill for skill in load_skills(enabled_only=False) if skill.category == "custom"]
skills = [skill for skill in get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False) if skill.category == SkillCategory.CUSTOM]
return SkillsListResponse(skills=[_skill_to_response(skill) for skill in skills])
except Exception as e:
logger.error("Failed to list custom skills: %s", e, exc_info=True)
@ -146,13 +136,14 @@ async def list_custom_skills() -> SkillsListResponse:
@router.get("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Get Custom Skill Content")
async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
async def get_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
skills = load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == "custom"), None)
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name and s.category == SkillCategory.CUSTOM), None)
if skill is None:
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=read_custom_skill_content(skill_name))
return CustomSkillContentResponse(**_skill_to_response(skill).model_dump(), content=get_or_new_skill_storage(app_config=config).read_custom_skill(skill_name))
except HTTPException:
raise
except Exception as e:
@ -161,30 +152,31 @@ async def get_custom_skill(skill_name: str) -> CustomSkillContentResponse:
@router.put("/skills/custom/{skill_name}", response_model=CustomSkillContentResponse, summary="Edit Custom Skill")
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest) -> CustomSkillContentResponse:
async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
ensure_custom_skill_is_editable(skill_name)
validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/SKILL.md")
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
storage = get_or_new_skill_storage(app_config=config)
storage.ensure_custom_skill_is_editable(skill_name)
storage.validate_skill_markdown_content(skill_name, request.content)
scan = await scan_skill_content(request.content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
if scan.decision == "block":
raise HTTPException(status_code=400, detail=f"Security scan blocked the edit: {scan.reason}")
skill_file = get_custom_skill_dir(skill_name) / "SKILL.md"
prev_content = skill_file.read_text(encoding="utf-8")
atomic_write(skill_file, request.content)
append_history(
prev_content = storage.read_custom_skill(skill_name)
storage.write_custom_skill(skill_name, SKILL_MD_FILE, request.content)
storage.append_history(
skill_name,
{
"action": "human_edit",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"file_path": SKILL_MD_FILE,
"prev_content": prev_content,
"new_content": request.content,
"scanner": {"decision": scan.decision, "reason": scan.reason},
},
)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
return await get_custom_skill(skill_name, config)
except HTTPException:
raise
except FileNotFoundError as e:
@ -197,29 +189,22 @@ async def update_custom_skill(skill_name: str, request: CustomSkillUpdateRequest
@router.delete("/skills/custom/{skill_name}", summary="Delete Custom Skill")
async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
async def delete_custom_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> dict[str, bool]:
try:
ensure_custom_skill_is_editable(skill_name)
skill_dir = get_custom_skill_dir(skill_name)
prev_content = read_custom_skill_content(skill_name)
try:
append_history(
skill_name,
{
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"prev_content": prev_content,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
except OSError as e:
if not isinstance(e, PermissionError) and e.errno not in {errno.EACCES, errno.EPERM, errno.EROFS}:
raise
logger.warning("Skipping delete history write for custom skill %s due to readonly/permission failure; continuing with skill directory removal: %s", skill_name, e)
shutil.rmtree(skill_dir)
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
storage = get_or_new_skill_storage(app_config=config)
storage.delete_custom_skill(
skill_name,
history_meta={
"action": "human_delete",
"author": "human",
"thread_id": None,
"file_path": SKILL_MD_FILE,
"prev_content": None,
"new_content": None,
"scanner": {"decision": "allow", "reason": "Deletion requested."},
},
)
await refresh_skills_system_prompt_cache_async()
return {"success": True}
except FileNotFoundError as e:
@ -232,11 +217,13 @@ async def delete_custom_skill(skill_name: str) -> dict[str, bool]:
@router.get("/skills/custom/{skill_name}/history", response_model=CustomSkillHistoryResponse, summary="Get Custom Skill History")
async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryResponse:
async def get_custom_skill_history(skill_name: str, config: AppConfig = Depends(get_config)) -> CustomSkillHistoryResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
storage = get_or_new_skill_storage(app_config=config)
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
return CustomSkillHistoryResponse(history=read_history(skill_name))
return CustomSkillHistoryResponse(history=storage.read_history(skill_name))
except HTTPException:
raise
except Exception as e:
@ -245,38 +232,39 @@ async def get_custom_skill_history(skill_name: str) -> CustomSkillHistoryRespons
@router.post("/skills/custom/{skill_name}/rollback", response_model=CustomSkillContentResponse, summary="Rollback Custom Skill")
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest) -> CustomSkillContentResponse:
async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest, config: AppConfig = Depends(get_config)) -> CustomSkillContentResponse:
try:
if not custom_skill_exists(skill_name) and not get_skill_history_file(skill_name).exists():
storage = get_or_new_skill_storage(app_config=config)
if not storage.custom_skill_exists(skill_name) and not storage.get_skill_history_file(skill_name).exists():
raise HTTPException(status_code=404, detail=f"Custom skill '{skill_name}' not found")
history = read_history(skill_name)
history = storage.read_history(skill_name)
if not history:
raise HTTPException(status_code=400, detail=f"Custom skill '{skill_name}' has no history")
record = history[request.history_index]
target_content = record.get("prev_content")
if target_content is None:
raise HTTPException(status_code=400, detail="Selected history entry has no previous content to roll back to")
validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/SKILL.md")
skill_file = get_custom_skill_file(skill_name)
storage.validate_skill_markdown_content(skill_name, target_content)
scan = await scan_skill_content(target_content, executable=False, location=f"{skill_name}/{SKILL_MD_FILE}", app_config=config)
skill_file = storage.get_custom_skill_file(skill_name)
current_content = skill_file.read_text(encoding="utf-8") if skill_file.exists() else None
history_entry = {
"action": "rollback",
"author": "human",
"thread_id": None,
"file_path": "SKILL.md",
"file_path": SKILL_MD_FILE,
"prev_content": current_content,
"new_content": target_content,
"rollback_from_ts": record.get("ts"),
"scanner": {"decision": scan.decision, "reason": scan.reason},
}
if scan.decision == "block":
append_history(skill_name, history_entry)
storage.append_history(skill_name, history_entry)
raise HTTPException(status_code=400, detail=f"Rollback blocked by security scanner: {scan.reason}")
atomic_write(skill_file, target_content)
append_history(skill_name, history_entry)
storage.write_custom_skill(skill_name, SKILL_MD_FILE, target_content)
storage.append_history(skill_name, history_entry)
await refresh_skills_system_prompt_cache_async()
return await get_custom_skill(skill_name)
return await get_custom_skill(skill_name, config)
except HTTPException:
raise
except IndexError:
@ -296,9 +284,10 @@ async def rollback_custom_skill(skill_name: str, request: SkillRollbackRequest)
summary="Get Skill Details",
description="Retrieve detailed information about a specific skill by its name.",
)
async def get_skill(skill_name: str) -> SkillResponse:
async def get_skill(skill_name: str, config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@ -318,9 +307,10 @@ async def get_skill(skill_name: str) -> SkillResponse:
summary="Update Skill",
description="Update a skill's enabled status by modifying the extensions_config.json file.",
)
async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillResponse:
async def update_skill(skill_name: str, request: SkillUpdateRequest, config: AppConfig = Depends(get_config)) -> SkillResponse:
try:
skills = load_skills(enabled_only=False)
skill_name = skill_name.replace("\r\n", "").replace("\n", "")
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == skill_name), None)
if skill is None:
@ -346,7 +336,7 @@ async def update_skill(skill_name: str, request: SkillUpdateRequest) -> SkillRes
reload_extensions_config()
await refresh_skills_system_prompt_cache_async()
skills = load_skills(enabled_only=False)
skills = get_or_new_skill_storage(app_config=config).load_skills(enabled_only=False)
updated_skill = next((s for s in skills if s.name == skill_name), None)
if updated_skill is None:

View File

@ -1,10 +1,13 @@
import json
import logging
from fastapi import APIRouter
from fastapi import APIRouter, Depends, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@ -98,12 +101,18 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
if not request.messages:
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(
thread_id: str,
body: SuggestionsRequest,
request: Request,
config: AppConfig = Depends(get_config),
) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[])
n = request.n
conversation = _format_conversation(request.messages)
n = body.n
conversation = _format_conversation(body.messages)
if not conversation:
return SuggestionsResponse(suggestions=[])
@ -120,8 +129,8 @@ async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> S
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
model = create_chat_model(name=body.model_name, thinking_enabled=False, app_config=config)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)], config={"run_name": "suggest_agent"})
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]

View File

@ -19,7 +19,8 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.deps import get_checkpointer, get_run_manager, get_stream_bridge
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
@ -92,6 +93,7 @@ def _record_to_response(record: RunRecord) -> RunResponse:
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request)
@ -99,6 +101,7 @@ async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
@ -126,6 +129,7 @@ async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request)
@ -151,6 +155,7 @@ async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) ->
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
run_mgr = get_run_manager(request)
@ -159,6 +164,7 @@ async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
run_mgr = get_run_manager(request)
@ -169,6 +175,7 @@ async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
thread_id: str,
run_id: str,
@ -206,6 +213,7 @@ async def cancel_run(
@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
bridge = get_stream_bridge(request)
@ -226,6 +234,7 @@ async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingRe
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
thread_id: str,
run_id: str,
@ -265,3 +274,104 @@ async def stream_existing_run(
"X-Accel-Buffering": "no",
},
)
# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
thread_id: str,
request: Request,
limit: int = Query(default=50, le=200),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> list[dict]:
"""Return displayable messages for a thread (across all runs), with feedback attached."""
event_store = get_run_event_store(request)
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
# Attach feedback to the last AI message of each run
feedback_repo = get_feedback_repo(request)
user_id = await get_current_user(request)
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
# Find the last ai_message per run_id
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
for i, msg in enumerate(messages):
if msg.get("event_type") == "ai_message":
last_ai_per_run[msg["run_id"]] = i
# Attach feedback field
last_ai_indices = set(last_ai_per_run.values())
for i, msg in enumerate(messages):
if i in last_ai_indices:
run_id = msg["run_id"]
fb = feedback_map.get(run_id)
msg["feedback"] = (
{
"feedback_id": fb["feedback_id"],
"rating": fb["rating"],
"comment": fb.get("comment"),
}
if fb
else None
)
else:
msg["feedback"] = None
return messages
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(
thread_id: str,
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a specific run.
Response: { data: [...], has_more: bool }
"""
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
thread_id,
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
thread_id: str,
run_id: str,
request: Request,
event_types: str | None = Query(default=None),
limit: int = Query(default=500, le=2000),
) -> list[dict]:
"""Return the full event stream for a run (debug/audit)."""
event_store = get_run_event_store(request)
types = event_types.split(",") if event_types else None
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return {"thread_id": thread_id, **agg}

View File

@ -18,23 +18,35 @@ import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from app.gateway.deps import get_checkpointer, get_store
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
# ---------------------------------------------------------------------------
# Store namespace
# ---------------------------------------------------------------------------
THREADS_NS: tuple[str, ...] = ("threads",)
"""Namespace used by the Store for thread metadata records."""
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
# Metadata keys that the server controls; clients are not allowed to set
# them. Pydantic ``@field_validator("metadata")`` strips them on every
# inbound model below so a malicious client cannot reflect a forged
# owner identity through the API surface. Defense-in-depth — the
# row-level invariant is still ``threads_meta.user_id`` populated from
# the auth contextvar; this list closes the metadata-blob echo gap.
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
"""Return ``metadata`` with server-controlled keys removed."""
if not metadata:
return metadata or {}
return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
# ---------------------------------------------------------------------------
# Response / request models
# ---------------------------------------------------------------------------
@ -63,8 +75,11 @@ class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
@ -93,6 +108,8 @@ class ThreadPatchRequest(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadStateUpdateRequest(BaseModel):
"""Request body for updating thread state (human-in-the-loop resume)."""
@ -126,70 +143,25 @@ class ThreadHistoryRequest(BaseModel):
# ---------------------------------------------------------------------------
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread."""
path_manager = paths or get_paths()
try:
path_manager.delete_thread_dir(thread_id)
path_manager.delete_thread_dir(thread_id, user_id=user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
# Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", thread_id)
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", thread_id)
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", thread_id)
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _store_get(store, thread_id: str) -> dict | None:
"""Fetch a thread record from the Store; returns ``None`` if absent."""
item = await store.aget(THREADS_NS, thread_id)
return item.value if item is not None else None
async def _store_put(store, record: dict) -> None:
"""Write a thread record to the Store."""
await store.aput(THREADS_NS, record["thread_id"], record)
async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None:
"""Create or refresh a thread record in the Store.
On creation the record is written with ``status="idle"``. On update only
``updated_at`` (and optionally ``metadata`` / ``values``) are changed so
that existing fields are preserved.
``values`` carries the agent-state snapshot exposed to the frontend
(currently just ``{"title": "..."}``).
"""
now = time.time()
existing = await _store_get(store, thread_id)
if existing is None:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": metadata or {},
"values": values or {},
},
)
else:
val = dict(existing)
val["updated_at"] = now
if metadata:
val.setdefault("metadata", {}).update(metadata)
if values:
val.setdefault("values", {}).update(values)
await _store_put(store, val)
def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None:
@ -215,22 +187,18 @@ def _derive_thread_status(checkpoint_tuple) -> str:
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread record from the Store.
and removes the thread_meta row from the configured ThreadMetaStore
(sqlite or memory).
"""
# Clean local filesystem
response = _delete_thread_data(thread_id)
from app.gateway.deps import get_thread_store
# Remove from Store (best-effort)
store = get_store(request)
if store is not None:
try:
await store.adelete(THREADS_NS, thread_id)
except Exception:
logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
# Clean local filesystem
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
@ -239,7 +207,15 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id))
# Remove thread_meta row (best-effort) — required for sqlite backend
# so the deleted thread no longer appears in /threads/search.
try:
thread_store = get_thread_store(request)
await thread_store.delete(thread_id)
except Exception:
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
return response
@ -248,43 +224,40 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread.
The thread record is written to the Store (for fast listing) and an
empty checkpoint is written to the checkpointer (for state reads).
Writes a thread_meta record (so the thread appears in /threads/search)
and an empty checkpoint (so state endpoints work immediately).
Idempotent: returns the existing record when ``thread_id`` already exists.
"""
store = get_store(request)
from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record from Store when already present
if store is not None:
existing_record = await _store_get(store, thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Idempotency: return existing record when already present
existing_record = await thread_store.get(thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Write thread record to Store
if store is not None:
try:
await _store_put(
store,
{
"thread_id": thread_id,
"status": "idle",
"created_at": now,
"updated_at": now,
"metadata": body.metadata,
},
)
except Exception:
logger.exception("Failed to write thread %s to store", thread_id)
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write thread_meta so the thread appears in /threads/search immediately
try:
await thread_store.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
metadata=body.metadata,
)
except Exception:
logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
@ -301,10 +274,10 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception:
logger.exception("Failed to create checkpoint for thread %s", thread_id)
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", thread_id)
logger.info("Thread created: %s", sanitize_log_param(thread_id))
return ThreadResponse(
thread_id=thread_id,
status="idle",
@ -318,166 +291,91 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads.
Two-phase approach:
**Phase 1 Store (fast path, O(threads))**: returns threads that were
created or run through this Gateway. Store records are tiny metadata
dicts so fetching all of them at once is cheap.
**Phase 2 Checkpointer supplement (lazy migration)**: threads that
were created directly by LangGraph Server (and therefore absent from the
Store) are discovered here by iterating the shared checkpointer. Any
newly found thread is immediately written to the Store so that the next
search skips Phase 2 for that thread the Store converges to a full
index over time without a one-shot migration job.
Delegates to the configured ThreadMetaStore implementation
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
"""
store = get_store(request)
checkpointer = get_checkpointer(request)
from app.gateway.deps import get_thread_store
# -----------------------------------------------------------------------
# Phase 1: Store
# -----------------------------------------------------------------------
merged: dict[str, ThreadResponse] = {}
if store is not None:
try:
items = await store.asearch(THREADS_NS, limit=10_000)
except Exception:
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
items = []
for item in items:
val = item.value
merged[val["thread_id"]] = ThreadResponse(
thread_id=val["thread_id"],
status=val.get("status", "idle"),
created_at=str(val.get("created_at", "")),
updated_at=str(val.get("updated_at", "")),
metadata=val.get("metadata", {}),
values=val.get("values", {}),
)
# -----------------------------------------------------------------------
# Phase 2: Checkpointer supplement
# Discovers threads not yet in the Store (e.g. created by LangGraph
# Server) and lazily migrates them so future searches skip this phase.
# -----------------------------------------------------------------------
try:
async for checkpoint_tuple in checkpointer.alist(None):
cfg = getattr(checkpoint_tuple, "config", {})
thread_id = cfg.get("configurable", {}).get("thread_id")
if not thread_id or thread_id in merged:
continue
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
continue
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
# Strip LangGraph internal keys from the user-visible metadata dict
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Extract state values (title) from the checkpoint's channel_values
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint_data.get("channel_values", {})
ckpt_values = {}
if title := channel_values.get("title"):
ckpt_values["title"] = title
thread_resp = ThreadResponse(
thread_id=thread_id,
status=_derive_thread_status(checkpoint_tuple),
created_at=str(ckpt_meta.get("created_at", "")),
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
metadata=user_meta,
values=ckpt_values,
)
merged[thread_id] = thread_resp
# Lazy migration — write to Store so the next search finds it there
if store is not None:
try:
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
except Exception:
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
except Exception:
logger.exception("Checkpointer scan failed during thread search")
# Don't raise — return whatever was collected from Store + partial scan
# -----------------------------------------------------------------------
# Phase 3: Filter → sort → paginate
# -----------------------------------------------------------------------
results = list(merged.values())
if body.metadata:
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
if body.status:
results = [r for r in results if r.status == body.status]
results.sort(key=lambda r: r.updated_at, reverse=True)
return results[body.offset : body.offset + body.limit]
repo = get_thread_store(request)
rows = await repo.search(
metadata=body.metadata or None,
status=body.status,
limit=body.limit,
offset=body.offset,
)
return [
ThreadResponse(
thread_id=r["thread_id"],
status=r.get("status", "idle"),
created_at=r.get("created_at", ""),
updated_at=r.get("updated_at", ""),
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {},
interrupts={},
)
for r in rows
]
@router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
store = get_store(request)
if store is None:
raise HTTPException(status_code=503, detail="Store not available")
from app.gateway.deps import get_thread_store
record = await _store_get(store, thread_id)
thread_store = get_thread_store(request)
record = await thread_store.get(thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
now = time.time()
updated = dict(record)
updated.setdefault("metadata", {}).update(body.metadata)
updated["updated_at"] = now
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
try:
await _store_put(store, updated)
await thread_store.update_metadata(thread_id, body.metadata)
except Exception:
logger.exception("Failed to patch thread %s", thread_id)
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread")
# Re-read to get the merged metadata + refreshed updated_at
record = await thread_store.get(thread_id) or record
return ThreadResponse(
thread_id=thread_id,
status=updated.get("status", "idle"),
created_at=str(updated.get("created_at", "")),
updated_at=str(now),
metadata=updated.get("metadata", {}),
status=record.get("status", "idle"),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
)
@router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the Store and derives the accurate execution
status from the checkpointer. Falls back to the checkpointer alone
for threads that pre-date Store adoption (backward compat).
Reads metadata from the ThreadMetaStore and derives the accurate
execution status from the checkpointer. Falls back to the checkpointer
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
"""
store = get_store(request)
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
checkpointer = get_checkpointer(request)
record: dict | None = None
if store is not None:
record = await _store_get(store, thread_id)
record: dict | None = await thread_store.get(thread_id)
# Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get checkpoint for thread %s", thread_id)
logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not the store (e.g. legacy
# data), synthesize a minimal store record from the checkpoint metadata.
# If the thread exists in the checkpointer but not in thread_meta (e.g.
# legacy data created before thread_meta adoption), synthesize a minimal
# record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = {
@ -505,7 +403,9 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
)
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
@ -518,7 +418,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
@ -542,8 +442,10 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
values = serialize_channel_values(channel_values)
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
values=values,
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
@ -555,15 +457,19 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field back to the Store
so that ``/threads/search`` reflects the change immediately.
channel values, then syncs any updated ``title`` field through the
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
change immediately in both sqlite and memory backends.
"""
from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request)
store = get_store(request)
thread_store = get_thread_store(request)
# checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it
@ -580,7 +486,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", thread_id)
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
@ -614,19 +520,22 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", thread_id)
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title changes to the Store so /threads/search reflects them immediately.
if store is not None and body.values and "title" in body.values:
try:
await _store_upsert(store, thread_id, values={"title": body.values["title"]})
except Exception:
logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
# reflects them immediately in both sqlite and memory backends.
if body.values and "title" in body.values:
new_title = body.values["title"]
if new_title: # Skip empty strings and None
try:
await thread_store.update_display_name(thread_id, new_title)
except Exception:
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
@ -638,8 +547,16 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
"""Get checkpoint history for a thread.
Messages are read from the checkpointer's channel values (the
authoritative source) and serialized via
:func:`~deerflow.runtime.serialization.serialize_channel_values`.
Only the latest (first) checkpoint carries the ``messages`` key to
avoid duplicating them across every entry.
"""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
@ -647,6 +564,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
is_latest_checkpoint = True
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {})
@ -661,22 +579,42 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
channel_values = checkpoint.get("channel_values", {})
# Build values from checkpoint channel_values
values: dict[str, Any] = {}
if title := channel_values.get("title"):
values["title"] = title
if thread_data := channel_values.get("thread_data"):
values["thread_data"] = thread_data
# Attach messages only to the latest checkpoint entry.
if is_latest_checkpoint:
messages = channel_values.get("messages")
if messages:
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
is_latest_checkpoint = False
# Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
# Strip LangGraph internal keys from metadata
user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Keep step for ordering context
if "step" in metadata:
user_meta["step"] = metadata["step"]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=metadata,
values=serialize_channel_values(channel_values),
metadata=user_meta,
values=values,
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", thread_id)
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries

View File

@ -4,11 +4,14 @@ import logging
import os
import stat
from fastapi import APIRouter, File, HTTPException, UploadFile
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from deerflow.config.app_config import get_app_config
from app.gateway.authz import require_permission
from app.gateway.deps import get_config
from deerflow.config.app_config import AppConfig
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox_provider import SandboxProvider, get_sandbox_provider
from deerflow.uploads.manager import (
PathTraversalError,
@ -58,23 +61,22 @@ def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
def _get_uploads_config_value(key: str, default: object) -> object:
def _get_uploads_config_value(app_config: AppConfig, key: str, default: object) -> object:
"""Read a value from the uploads config, supporting dict and attribute access."""
cfg = get_app_config()
uploads_cfg = getattr(cfg, "uploads", None)
uploads_cfg = getattr(app_config, "uploads", None)
if isinstance(uploads_cfg, dict):
return uploads_cfg.get(key, default)
return getattr(uploads_cfg, key, default)
def _auto_convert_documents_enabled() -> bool:
def _auto_convert_documents_enabled(app_config: AppConfig) -> bool:
"""Return whether automatic host-side document conversion is enabled.
The secure default is disabled unless an operator explicitly opts in via
uploads.auto_convert_documents in config.yaml.
"""
try:
raw = _get_uploads_config_value("auto_convert_documents", False)
raw = _get_uploads_config_value(app_config, "auto_convert_documents", False)
if isinstance(raw, str):
return raw.strip().lower() in {"1", "true", "yes", "on"}
return bool(raw)
@ -83,9 +85,12 @@ def _auto_convert_documents_enabled() -> bool:
@router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=False)
async def upload_files(
thread_id: str,
request: Request,
files: list[UploadFile] = File(...),
config: AppConfig = Depends(get_config),
) -> UploadResponse:
"""Upload multiple files to a thread's uploads directory."""
if not files:
@ -95,7 +100,7 @@ async def upload_files(
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
sandbox_provider = get_sandbox_provider()
@ -104,7 +109,7 @@ async def upload_files(
if sync_to_sandbox:
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
auto_convert_documents = _auto_convert_documents_enabled()
auto_convert_documents = _auto_convert_documents_enabled(config)
for file in files:
if not file.filename:
@ -166,7 +171,8 @@ async def upload_files(
@router.get("/list", response_model=dict)
async def list_uploaded_files(thread_id: str) -> dict:
@require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
@ -176,7 +182,7 @@ async def list_uploaded_files(thread_id: str) -> dict:
enrich_file_listing(result, thread_id)
# Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id)
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
@ -184,7 +190,8 @@ async def list_uploaded_files(thread_id: str) -> dict:
@router.delete("/{filename}")
async def delete_uploaded_file(thread_id: str, filename: str) -> dict:
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
"""Delete a file from a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)

View File

@ -11,13 +11,14 @@ import asyncio
import json
import logging
import re
import time
from collections.abc import Mapping
from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_manager, get_store, get_stream_bridge
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.utils import sanitize_log_param
from deerflow.runtime import (
END_SENTINEL,
HEARTBEAT_SENTINEL,
@ -97,13 +98,52 @@ def normalize_input(raw_input: dict[str, Any] | None) -> dict[str, Any]:
_DEFAULT_ASSISTANT_ID = "lead_agent"
# Whitelist of run-context keys that the langgraph-compat layer forwards from
# ``body.context`` into the run config. ``config["context"]`` exists in
# LangGraph >=0.6, but these values must be written to both ``configurable``
# (for legacy ``_get_runtime_config`` consumers) and ``context`` because
# LangGraph >=1.1.9 no longer makes ``ToolRuntime.context`` fall back to
# ``configurable`` for consumers like ``setup_agent``.
_CONTEXT_CONFIGURABLE_KEYS: frozenset[str] = frozenset(
{
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
"agent_name",
"is_bootstrap",
}
)
def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, Any] | None) -> None:
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
and ``config['context']`` so they are visible to legacy configurable readers and
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool
see issue #2677)."""
if not context:
return
configurable = config.setdefault("configurable", {})
runtime_context = config.setdefault("context", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
if isinstance(configurable, dict):
configurable.setdefault(key, context[key])
if isinstance(runtime_context, dict):
runtime_context.setdefault(key, context[key])
def resolve_agent_factory(assistant_id: str | None):
"""Resolve the agent factory callable from config.
Custom agents are implemented as ``lead_agent`` + an ``agent_name``
injected into ``configurable`` see :func:`build_run_config`. All
``assistant_id`` values therefore map to the same factory; the routing
happens inside ``make_lead_agent`` when it reads ``cfg["agent_name"]``.
injected into ``configurable`` or ``context`` see
:func:`build_run_config`. All ``assistant_id`` values therefore map to the
same factory; the routing happens inside ``make_lead_agent`` when it reads
``cfg["agent_name"]``.
"""
from deerflow.agents.lead_agent.agent import make_lead_agent
@ -120,10 +160,12 @@ def build_run_config(
"""Build a RunnableConfig dict for the agent.
When *assistant_id* refers to a custom agent (anything other than
``"lead_agent"`` / ``None``), the name is forwarded as
``configurable["agent_name"]``. ``make_lead_agent`` reads this key to
load the matching ``agents/<name>/SOUL.md`` and per-agent config
without it the agent silently runs as the default lead agent.
``"lead_agent"`` / ``None``), the name is forwarded as ``agent_name`` in
whichever runtime options container is active: ``context`` for
LangGraph >= 0.6.0 requests, otherwise ``configurable``.
``make_lead_agent`` reads this key to load the matching
``agents/<name>/SOUL.md`` and per-agent config without it the agent
silently runs as the default lead agent.
This mirrors the channel manager's ``_resolve_run_params`` logic so that
the LangGraph Platform-compatible HTTP API and the IM channel path behave
@ -142,7 +184,14 @@ def build_run_config(
thread_id,
list(request_config.get("configurable", {}).keys()),
)
config["context"] = request_config["context"]
context_value = request_config["context"]
if context_value is None:
context = {}
elif isinstance(context_value, Mapping):
context = dict(context_value)
else:
raise ValueError("request config 'context' must be a mapping or null.")
config["context"] = context
else:
configurable = {"thread_id": thread_id}
configurable.update(request_config.get("configurable", {}))
@ -154,13 +203,19 @@ def build_run_config(
config["configurable"] = {"thread_id": thread_id}
# Inject custom agent name when the caller specified a non-default assistant.
# Honour an explicit configurable["agent_name"] in the request if already set.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID and "configurable" in config:
if "agent_name" not in config["configurable"]:
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
config["configurable"]["agent_name"] = normalized
# Honour an explicit agent_name in the active runtime options container.
if assistant_id and assistant_id != _DEFAULT_ASSISTANT_ID:
normalized = assistant_id.strip().lower().replace("_", "-")
if not normalized or not re.fullmatch(r"[a-z0-9-]+", normalized):
raise ValueError(f"Invalid assistant_id {assistant_id!r}: must contain only letters, digits, and hyphens after normalization.")
if "configurable" in config:
target = config["configurable"]
elif "context" in config:
target = config["context"]
else:
target = config.setdefault("configurable", {})
if target is not None and "agent_name" not in target:
target["agent_name"] = normalized
if metadata:
config.setdefault("metadata", {}).update(metadata)
return config
@ -171,71 +226,6 @@ def build_run_config(
# ---------------------------------------------------------------------------
async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None) -> None:
"""Create or refresh the thread record in the Store.
Called from :func:`start_run` so that threads created via the stateless
``/runs/stream`` endpoint (which never calls ``POST /threads``) still
appear in ``/threads/search`` results.
"""
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_upsert
try:
await _store_upsert(store, thread_id, metadata=metadata)
except Exception:
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
async def _sync_thread_title_after_run(
run_task: asyncio.Task,
thread_id: str,
checkpointer: Any,
store: Any,
) -> None:
"""Wait for *run_task* to finish, then persist the generated title to the Store.
TitleMiddleware writes the generated title to the LangGraph agent state
(checkpointer) but the Gateway's Store record is not updated automatically.
This coroutine closes that gap by reading the final checkpoint after the
run completes and syncing ``values.title`` into the Store record so that
subsequent ``/threads/search`` responses include the correct title.
Runs as a fire-and-forget :func:`asyncio.create_task`; failures are
logged at DEBUG level and never propagate.
"""
# Wait for the background run task to complete (any outcome).
# asyncio.wait does not propagate task exceptions — it just returns
# when the task is done, cancelled, or failed.
await asyncio.wait({run_task})
# Deferred import to avoid circular import with the threads router module.
from app.gateway.routers.threads import _store_get, _store_put
try:
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
if ckpt_tuple is None:
return
channel_values = ckpt_tuple.checkpoint.get("channel_values", {})
title = channel_values.get("title")
if not title:
return
existing = await _store_get(store, thread_id)
if existing is None:
return
updated = dict(existing)
updated.setdefault("values", {})["title"] = title
updated["updated_at"] = time.time()
await _store_put(store, updated)
logger.debug("Synced title %r for thread %s", title, thread_id)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id, exc_info=True)
async def start_run(
body: Any,
thread_id: str,
@ -255,8 +245,7 @@ async def start_run(
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
run_ctx = get_run_context(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
@ -274,37 +263,31 @@ async def start_run(
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Ensure the thread is visible in /threads/search, even for threads that
# were never explicitly created via POST /threads (e.g. stateless runs).
store = get_store(request)
if store is not None:
await _upsert_thread_in_store(store, thread_id, body.metadata)
# Upsert thread metadata so the thread appears in /threads/search,
# even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await run_ctx.thread_store.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input)
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Merge DeerFlow-specific context overrides into configurable.
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
context = getattr(body, "context", None)
if context:
_CONTEXT_CONFIGURABLE_KEYS = {
"model_name",
"mode",
"thinking_enabled",
"reasoning_effort",
"is_plan_mode",
"subagent_enabled",
"max_concurrent_subagents",
"agent_name",
"is_bootstrap",
}
configurable = config.setdefault("configurable", {})
for key in _CONTEXT_CONFIGURABLE_KEYS:
if key in context:
configurable.setdefault(key, context[key])
merge_run_context_overrides(config, getattr(body, "context", None))
stream_modes = normalize_stream_modes(body.stream_mode)
@ -313,8 +296,7 @@ async def start_run(
bridge,
run_mgr,
record,
checkpointer=checkpointer,
store=store,
ctx=run_ctx,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
@ -326,11 +308,9 @@ async def start_run(
)
record.task = task
# After the run completes, sync the title generated by TitleMiddleware from
# the checkpointer into the Store record so that /threads/search returns the
# correct title instead of an empty values dict.
if store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
# Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_store.update_display_name
# after the run completes.
return record

View File

@ -0,0 +1,6 @@
"""Shared utility helpers for the Gateway layer."""
def sanitize_log_param(value: str) -> str:
"""Strip control characters to prevent log injection."""
return value.replace("\n", "").replace("\r", "").replace("\x00", "")

View File

@ -19,24 +19,70 @@ import asyncio
import logging
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from deerflow.agents import make_lead_agent
try:
from prompt_toolkit import PromptSession
from prompt_toolkit.history import InMemoryHistory
_HAS_PROMPT_TOOLKIT = True
except ImportError:
_HAS_PROMPT_TOOLKIT = False
load_dotenv()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
_LOG_FMT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
def _setup_logging(log_level: int = logging.INFO) -> None:
"""Route logs to ``debug.log`` using *log_level* for the initial root/file setup.
This configures the root logger and the ``debug.log`` file handler so logs do
not print on the interactive console. It is idempotent: any pre-existing
handlers on the root logger (e.g. installed by ``logging.basicConfig`` in
transitively imported modules) are removed so the debug session output only
lands in ``debug.log``.
Note: later config-driven logging adjustments may change named logger
verbosity without raising the root logger or file-handler thresholds set
here, so the eventual contents of ``debug.log`` may not be filtered solely by
this function's ``log_level`` argument.
"""
root = logging.root
for h in list(root.handlers):
root.removeHandler(h)
h.close()
root.setLevel(log_level)
file_handler = logging.FileHandler("debug.log", mode="a", encoding="utf-8")
file_handler.setLevel(log_level)
file_handler.setFormatter(logging.Formatter(_LOG_FMT, datefmt=_LOG_DATEFMT))
root.addHandler(file_handler)
async def main():
# Install file logging first so warnings emitted while loading config do not
# leak onto the interactive terminal via Python's lastResort handler.
_setup_logging()
from deerflow.config import get_app_config
from deerflow.config.app_config import apply_logging_level
app_config = get_app_config()
apply_logging_level(app_config.log_level)
# Delay the rest of the deerflow imports until *after* logging is installed
# so that any import-time side effects (e.g. deerflow.agents starts a
# background skill-loader thread on import) emit logs to debug.log instead
# of leaking onto the interactive terminal via Python's lastResort handler.
from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.agents import make_lead_agent
from deerflow.mcp import initialize_mcp_tools
# Initialize MCP tools at startup
try:
from deerflow.mcp import initialize_mcp_tools
await initialize_mcp_tools()
except Exception as e:
print(f"Warning: Failed to initialize MCP tools: {e}")
@ -52,16 +98,27 @@ async def main():
}
}
runtime = Runtime(context={"thread_id": config["configurable"]["thread_id"]})
config["configurable"]["__pregel_runtime"] = runtime
agent = make_lead_agent(config)
session = PromptSession(history=InMemoryHistory()) if _HAS_PROMPT_TOOLKIT else None
print("=" * 50)
print("Lead Agent Debug Mode")
print("Type 'quit' or 'exit' to stop")
print(f"Logs: debug.log (log_level={app_config.log_level})")
if not _HAS_PROMPT_TOOLKIT:
print("Tip: `uv sync --group dev` to enable arrow-key & history support")
print("=" * 50)
while True:
try:
user_input = input("\nYou: ").strip()
if session:
user_input = (await session.prompt_async("\nYou: ")).strip()
else:
user_input = input("\nYou: ").strip()
if not user_input:
continue
if user_input.lower() in ("quit", "exit"):
@ -70,15 +127,15 @@ async def main():
# Invoke the agent
state = {"messages": [HumanMessage(content=user_input)]}
result = await agent.ainvoke(state, config=config, context={"thread_id": "debug-thread-001"})
result = await agent.ainvoke(state, config=config)
# Print the response
if result.get("messages"):
last_message = result["messages"][-1]
print(f"\nAgent: {last_message.content}")
except KeyboardInterrupt:
print("\nInterrupted. Goodbye!")
except (KeyboardInterrupt, EOFError):
print("\nGoodbye!")
break
except Exception as e:
print(f"\nError: {e}")

View File

@ -0,0 +1,77 @@
# Docker Test Gap (Section 七 7.4)
This file documents the only **un-executed** test cases from
`backend/docs/AUTH_TEST_PLAN.md` after the full release validation pass.
## Why this gap exists
The release validation environment (sg_dev: `10.251.229.92`) **does not have
a Docker daemon installed**. The TC-DOCKER cases are container-runtime
behavior tests that need an actual Docker engine to spin up
`docker/docker-compose.yaml` services.
```bash
$ ssh sg_dev "which docker; docker --version"
# (empty)
# bash: docker: command not found
```
All other test plan sections were executed against either:
- The local dev box (Mac, all services running locally), or
- The deployed sg_dev instance (gateway + frontend + nginx via SSH tunnel)
## Cases not executed
| Case | Title | What it covers | Why not run |
|---|---|---|---|
| TC-DOCKER-01 | `users.db` volume persistence | Verify the `DEER_FLOW_HOME` bind mount survives container restart | needs `docker compose up` |
| TC-DOCKER-02 | Session persistence across container restart | `AUTH_JWT_SECRET` env var keeps cookies valid after `docker compose down && up` | needs `docker compose down/up` |
| TC-DOCKER-03 | Per-worker rate limiter divergence | Confirms in-process `_login_attempts` dict doesn't share state across `gunicorn` workers (4 by default in the compose file); known limitation, documented | needs multi-worker container |
| TC-DOCKER-04 | IM channels skip AuthMiddleware | Verify Feishu/Slack/Telegram dispatchers run in-container against `http://langgraph:2024` without going through nginx | needs `docker logs` |
| TC-DOCKER-05 | Admin credentials surfacing | **Updated post-simplify** — was "log scrape", now "0600 credential file in `DEER_FLOW_HOME`". The file-based behavior is already validated by TC-1.1 + TC-UPG-13 on sg_dev (non-Docker), so the only Docker-specific gap is verifying the volume mount carries the file out to the host | needs container + host volume |
| TC-DOCKER-06 | Gateway-mode Docker deploy | `./scripts/deploy.sh --gateway` produces a 3-container topology (no `langgraph` container); same auth flow as standard mode | needs `docker compose --profile gateway` |
## Coverage already provided by non-Docker tests
The **auth-relevant** behavior in each Docker case is already exercised by
the test cases that ran on sg_dev or local:
| Docker case | Auth behavior covered by |
|---|---|
| TC-DOCKER-01 (volume persistence) | TC-REENT-01 on sg_dev (admin row survives gateway restart) — same SQLite file, just no container layer between |
| TC-DOCKER-02 (session persistence) | TC-API-02/03/06 (cookie roundtrip), plus TC-REENT-04 (multi-cookie) — JWT verification is process-state-free, container restart is equivalent to `pkill uvicorn && uv run uvicorn` |
| TC-DOCKER-03 (per-worker rate limit) | TC-GW-04 + TC-REENT-09 (single-worker rate limit + 5min expiry). The cross-worker divergence is an architectural property of the in-memory dict; no auth code path differs |
| TC-DOCKER-04 (IM channels skip auth) | Code-level only: `app/channels/manager.py` uses `langgraph_sdk` directly with no cookie handling. The langgraph_auth handler is bypassed by going through SDK, not HTTP |
| TC-DOCKER-05 (credential surfacing) | TC-1.1 on sg_dev (file at `~/deer-flow/backend/.deer-flow/admin_initial_credentials.txt`, mode 0600, password 22 chars) — the only Docker-unique step is whether the bind mount projects this path onto the host, which is a `docker compose` config check, not a runtime behavior change |
| TC-DOCKER-06 (gateway-mode container) | Section 七 7.2 covered by TC-GW-01..05 + Section 二 (gateway-mode auth flow on sg_dev) — same Gateway code, container is just a packaging change |
## Reproduction steps when Docker becomes available
Anyone with `docker` + `docker compose` installed can reproduce the gap by
running the test plan section verbatim. Pre-flight:
```bash
# Required on the host
docker --version # >=24.x
docker compose version # plugin >=2.x
# Required env var (otherwise sessions reset on every container restart)
echo "AUTH_JWT_SECRET=$(python3 -c 'import secrets; print(secrets.token_urlsafe(32))')" \
>> .env
# Optional: pin DEER_FLOW_HOME to a stable host path
echo "DEER_FLOW_HOME=$HOME/deer-flow-data" >> .env
```
Then run TC-DOCKER-01..06 from the test plan as written.
## Decision log
- **Not blocking the release.** The auth-relevant behavior in every Docker
case has an already-validated equivalent on bare metal. The gap is purely
about *container packaging* details (bind mounts, multi-worker, log
collection), not about whether the auth code paths work.
- **TC-DOCKER-05 was updated in place** in `AUTH_TEST_PLAN.md` to reflect
the post-simplify reality (credentials file → 0600 file, no log leak).
The old "grep 'Password:' in docker logs" expectation would have failed
silently and given a false sense of coverage.

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,129 @@
# Authentication Upgrade Guide
DeerFlow 内置了认证模块。本文档面向从无认证版本升级的用户。
## 核心概念
认证模块采用**始终强制**策略:
- 首次启动时自动创建 admin 账号,随机密码打印到控制台日志
- 认证从一开始就是强制的,无竞争窗口
- 历史对话(升级前创建的 thread自动迁移到 admin 名下
## 升级步骤
### 1. 更新代码
```bash
git pull origin main
cd backend && make install
```
### 2. 首次启动
```bash
make dev
```
控制台会输出:
```
============================================================
Admin account created on first boot
Email: admin@deerflow.dev
Password: aB3xK9mN_pQ7rT2w
Change it after login: Settings → Account
============================================================
```
如果未登录就重启了服务,不用担心——只要 setup 未完成,每次启动都会重置密码并重新打印到控制台。
### 3. 登录
访问 `http://localhost:2026/login`,使用控制台输出的邮箱和密码登录。
### 4. 修改密码
登录后进入 Settings → Account → Change Password。
### 5. 添加用户(可选)
其他用户通过 `/login` 页面注册,自动获得 **user** 角色。每个用户只能看到自己的对话。
## 安全机制
| 机制 | 说明 |
|------|------|
| JWT HttpOnly Cookie | Token 不暴露给 JavaScript防止 XSS 窃取 |
| CSRF Double Submit Cookie | 所有 POST/PUT/DELETE 请求需携带 `X-CSRF-Token` |
| bcrypt 密码哈希 | 密码不以明文存储 |
| 多租户隔离 | 用户只能访问自己的 thread |
| HTTPS 自适应 | 检测 `x-forwarded-proto`,自动设置 `Secure` cookie 标志 |
## 常见操作
### 忘记密码
```bash
cd backend
# 重置 admin 密码
python -m app.gateway.auth.reset_admin
# 重置指定用户密码
python -m app.gateway.auth.reset_admin --email user@example.com
```
会输出新的随机密码。
### 完全重置
删除用户数据库,重启后自动创建新 admin
```bash
rm -f backend/.deer-flow/users.db
# 重启服务,控制台输出新密码
```
## 数据存储
| 文件 | 内容 |
|------|------|
| `.deer-flow/users.db` | SQLite 用户数据库(密码哈希、角色) |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
### 生产环境建议
```bash
# 生成持久化 JWT 密钥,避免重启后所有用户需重新登录
python -c "import secrets; print(secrets.token_urlsafe(32))"
# 将输出添加到 .env
# AUTH_JWT_SECRET=<生成的密钥>
```
## API 端点
| 端点 | 方法 | 说明 |
|------|------|------|
| `/api/v1/auth/login/local` | POST | 邮箱密码登录OAuth2 form |
| `/api/v1/auth/register` | POST | 注册新用户user 角色) |
| `/api/v1/auth/logout` | POST | 登出(清除 cookie |
| `/api/v1/auth/me` | GET | 获取当前用户信息 |
| `/api/v1/auth/change-password` | POST | 修改密码 |
| `/api/v1/auth/setup-status` | GET | 检查 admin 是否存在 |
## 兼容性
- **标准模式**`make dev`完全兼容admin 自动创建
- **Gateway 模式**`make dev-pro`):完全兼容
- **Docker 部署**:完全兼容,`.deer-flow/users.db` 需持久化卷挂载
- **IM 渠道**Feishu/Slack/Telegram通过 LangGraph SDK 通信,不经过认证层
- **DeerFlowClient**(嵌入式):不经过 HTTP不受认证影响
## 故障排查
| 症状 | 原因 | 解决 |
|------|------|------|
| 启动后没看到密码 | admin 已存在(非首次启动) | 用 `reset_admin` 重置,或删 `users.db` |
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |

View File

@ -259,6 +259,8 @@ sandbox:
When you configure `sandbox.mounts`, DeerFlow exposes those `container_path` values in the agent prompt so the agent can discover and operate on mounted directories directly instead of assuming everything must live under `/mnt/user-data`.
For bare-metal Docker sandbox runs that use localhost, DeerFlow binds the sandbox HTTP port to `127.0.0.1` by default so it is not exposed on every host interface. Docker-outside-of-Docker deployments that connect through `host.docker.internal` keep the broad legacy bind for compatibility. Set `DEER_FLOW_SANDBOX_BIND_HOST` explicitly if your deployment needs a different bind address.
### Skills
Configure the skills directory for specialized workflows:
@ -320,6 +322,7 @@ models:
- `NOVITA_API_KEY` - Novita API key (OpenAI-compatible endpoint)
- `TAVILY_API_KEY` - Tavily search API key
- `DEER_FLOW_CONFIG_PATH` - Custom config file path
- `GATEWAY_ENABLE_DOCS` - Set to `false` to disable Swagger UI (`/docs`), ReDoc (`/redoc`), and OpenAPI schema (`/openapi.json`) endpoints (default: `true`)
## Configuration Location

View File

@ -1,343 +0,0 @@
# DeerFlow 后端拆分设计文档Harness + App
> 状态Draft
> 作者DeerFlow Team
> 日期2026-03-13
## 1. 背景与动机
DeerFlow 后端当前是一个单一 Python 包(`src.*`),包含了从底层 agent 编排到上层用户产品的所有代码。随着项目发展,这种结构带来了几个问题:
- **复用困难**其他产品CLI 工具、Slack bot、第三方集成想用 agent 能力,必须依赖整个后端,包括 FastAPI、IM SDK 等不需要的依赖
- **职责模糊**agent 编排逻辑和用户产品逻辑混在同一个 `src/` 下,边界不清晰
- **依赖膨胀**LangGraph Server 运行时不需要 FastAPI/uvicorn/Slack SDK但当前必须安装全部依赖
本文档提出将后端拆分为两部分:**deerflow-harness**(可发布的 agent 框架包)和 **app**(不打包的用户产品代码)。
## 2. 核心概念
### 2.1 Harness线束/框架层)
Harness 是 agent 的构建与编排框架,回答 **"如何构建和运行 agent"** 的问题:
- Agent 工厂与生命周期管理
- Middleware pipeline
- 工具系统(内置工具 + MCP + 社区工具)
- 沙箱执行环境
- 子 agent 委派
- 记忆系统
- 技能加载与注入
- 模型工厂
- 配置系统
**Harness 是一个可发布的 Python 包**`deerflow-harness`),可以独立安装和使用。
**Harness 的设计原则**:对上层应用完全无感知。它不知道也不关心谁在调用它——可以是 Web App、CLI、Slack Bot、或者一个单元测试。
### 2.2 App应用层
App 是面向用户的产品代码,回答 **"如何将 agent 呈现给用户"** 的问题:
- Gateway APIFastAPI REST 接口)
- IM Channels飞书、Slack、Telegram 集成)
- Custom Agent 的 CRUD 管理
- 文件上传/下载的 HTTP 接口
**App 不打包、不发布**,它是 DeerFlow 项目内部的应用代码,直接运行。
**App 依赖 Harness但 Harness 不依赖 App。**
### 2.3 边界划分
| 模块 | 归属 | 说明 |
|------|------|------|
| `config/` | Harness | 配置系统是基础设施 |
| `reflection/` | Harness | 动态模块加载工具 |
| `utils/` | Harness | 通用工具函数 |
| `agents/` | Harness | Agent 工厂、middleware、state、memory |
| `subagents/` | Harness | 子 agent 委派系统 |
| `sandbox/` | Harness | 沙箱执行环境 |
| `tools/` | Harness | 工具注册与发现 |
| `mcp/` | Harness | MCP 协议集成 |
| `skills/` | Harness | 技能加载、解析、定义 schema |
| `models/` | Harness | LLM 模型工厂 |
| `community/` | Harness | 社区工具tavily、jina 等) |
| `client.py` | Harness | 嵌入式 Python 客户端 |
| `gateway/` | App | FastAPI REST API |
| `channels/` | App | IM 平台集成 |
**关于 Custom Agents**agent 定义格式(`config.yaml` + `SOUL.md` schema由 Harness 层的 `config/agents_config.py` 定义但文件的存储、CRUD、发现机制由 App 层的 `gateway/routers/agents.py` 负责。
## 3. 目标架构
### 3.1 目录结构
```
backend/
├── packages/
│ └── harness/
│ ├── pyproject.toml # deerflow-harness 包定义
│ └── deerflow/ # Python 包根import 前缀: deerflow.*
│ ├── __init__.py
│ ├── config/
│ ├── reflection/
│ ├── utils/
│ ├── agents/
│ │ ├── lead_agent/
│ │ ├── middlewares/
│ │ ├── memory/
│ │ ├── checkpointer/
│ │ └── thread_state.py
│ ├── subagents/
│ ├── sandbox/
│ ├── tools/
│ ├── mcp/
│ ├── skills/
│ ├── models/
│ ├── community/
│ └── client.py
├── app/ # 不打包import 前缀: app.*
│ ├── __init__.py
│ ├── gateway/
│ │ ├── __init__.py
│ │ ├── app.py
│ │ ├── config.py
│ │ ├── path_utils.py
│ │ └── routers/
│ └── channels/
│ ├── __init__.py
│ ├── base.py
│ ├── manager.py
│ ├── service.py
│ ├── store.py
│ ├── message_bus.py
│ ├── feishu.py
│ ├── slack.py
│ └── telegram.py
├── pyproject.toml # uv workspace root
├── langgraph.json
├── tests/
├── docs/
└── Makefile
```
### 3.2 Import 规则
两个层使用不同的 import 前缀,职责边界一目了然:
```python
# ---------------------------------------------------------------
# Harness 内部互相引用deerflow.* 前缀)
# ---------------------------------------------------------------
from deerflow.agents import make_lead_agent
from deerflow.models import create_chat_model
from deerflow.config import get_app_config
from deerflow.tools import get_available_tools
# ---------------------------------------------------------------
# App 内部互相引用app.* 前缀)
# ---------------------------------------------------------------
from app.gateway.app import app
from app.gateway.routers.uploads import upload_files
from app.channels.service import start_channel_service
# ---------------------------------------------------------------
# App 调用 Harness单向依赖Harness 永远不 import app
# ---------------------------------------------------------------
from deerflow.agents import make_lead_agent
from deerflow.models import create_chat_model
from deerflow.skills import load_skills
from deerflow.config.extensions_config import get_extensions_config
```
**App 调用 Harness 示例 — Gateway 中启动 agent**
```python
# app/gateway/routers/chat.py
from deerflow.agents.lead_agent.agent import make_lead_agent
from deerflow.models import create_chat_model
from deerflow.config import get_app_config
async def create_chat_session(thread_id: str, model_name: str):
config = get_app_config()
model = create_chat_model(name=model_name)
agent = make_lead_agent(config=...)
# ... 使用 agent 处理用户消息
```
**App 调用 Harness 示例 — Channel 中查询 skills**
```python
# app/channels/manager.py
from deerflow.skills import load_skills
from deerflow.agents.memory.updater import get_memory_data
def handle_status_command():
skills = load_skills(enabled_only=True)
memory = get_memory_data()
return f"Skills: {len(skills)}, Memory facts: {len(memory.get('facts', []))}"
```
**禁止方向**Harness 代码中绝不能出现 `from app.``import app.`
### 3.3 为什么 App 不打包
| 方面 | 打包(放 packages/ 下) | 不打包(放 backend/app/ |
|------|------------------------|--------------------------|
| 命名空间 | 需要 pkgutil `extend_path` 合并,或独立前缀 | 天然独立,`app.*` vs `deerflow.*` |
| 发布需求 | 没有——App 是项目内部代码 | 不需要 pyproject.toml |
| 复杂度 | 需要管理两个包的构建、版本、依赖声明 | 直接运行,零额外配置 |
| 运行方式 | `pip install deerflow-app` | `PYTHONPATH=. uvicorn app.gateway.app:app` |
App 的唯一消费者是 DeerFlow 项目自身,没有独立发布的需求。放在 `backend/app/` 下作为普通 Python 包,通过 `PYTHONPATH` 或 editable install 让 Python 找到即可。
### 3.4 依赖关系
```
┌─────────────────────────────────────┐
│ app/ (不打包,直接运行) │
│ ├── fastapi, uvicorn │
│ ├── slack-sdk, lark-oapi, ... │
│ └── import deerflow.* │
└──────────────┬──────────────────────┘
┌─────────────────────────────────────┐
│ deerflow-harness (可发布的包) │
│ ├── langgraph, langchain │
│ ├── markitdown, pydantic, ... │
│ └── 零 app 依赖 │
└─────────────────────────────────────┘
```
**依赖分类**
| 分类 | 依赖包 |
|------|--------|
| Harness only | agent-sandbox, langchain*, langgraph*, markdownify, markitdown, pydantic, pyyaml, readabilipy, tavily-python, firecrawl-py, tiktoken, ddgs, duckdb, httpx, kubernetes, dotenv |
| App only | fastapi, uvicorn, sse-starlette, python-multipart, lark-oapi, slack-sdk, python-telegram-bot, markdown-to-mrkdwn |
| Shared | langgraph-sdkchannels 用 HTTP client, pydantic, httpx |
### 3.5 Workspace 配置
`backend/pyproject.toml`workspace root
```toml
[project]
name = "deer-flow"
version = "0.1.0"
requires-python = ">=3.12"
dependencies = ["deerflow-harness"]
[dependency-groups]
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
# App 的额外依赖fastapi 等)也声明在 workspace root因为 app 不打包
app = ["fastapi", "uvicorn", "sse-starlette", "python-multipart"]
channels = ["lark-oapi", "slack-sdk", "python-telegram-bot"]
[tool.uv.workspace]
members = ["packages/harness"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
```
## 4. 当前的跨层依赖问题
在拆分之前,需要先解决 `client.py` 中两处从 harness 到 app 的反向依赖:
### 4.1 `_validate_skill_frontmatter`
```python
# client.py — harness 导入了 app 层代码
from src.gateway.routers.skills import _validate_skill_frontmatter
```
**解决方案**:将该函数提取到 `deerflow/skills/validation.py`。这是一个纯逻辑函数(解析 YAML frontmatter、校验字段与 FastAPI 无关。
### 4.2 `CONVERTIBLE_EXTENSIONS` + `convert_file_to_markdown`
```python
# client.py — harness 导入了 app 层代码
from src.gateway.routers.uploads import CONVERTIBLE_EXTENSIONS, convert_file_to_markdown
```
**解决方案**:将它们提取到 `deerflow/utils/file_conversion.py`。仅依赖 `markitdown` + `pathlib`,是通用工具函数。
## 5. 基础设施变更
### 5.1 LangGraph Server
LangGraph Server 只需要 harness 包。`langgraph.json` 更新:
```json
{
"dependencies": ["./packages/harness"],
"graphs": {
"lead_agent": "deerflow.agents:make_lead_agent"
},
"checkpointer": {
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
}
}
```
### 5.2 Gateway API
```bash
# serve.sh / Makefile
# PYTHONPATH 包含 backend/ 根目录,使 app.* 和 deerflow.* 都能被找到
PYTHONPATH=. uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001
```
### 5.3 Nginx
无需变更(只做 URL 路由,不涉及 Python 模块路径)。
### 5.4 Docker
Dockerfile 中的 module 引用从 `src.` 改为 `deerflow.` / `app.``COPY` 命令需覆盖 `packages/``app/` 目录。
## 6. 实施计划
分 3 个 PR 递进执行:
### PR 1提取共享工具函数Low Risk
1. 创建 `src/skills/validation.py`,从 `gateway/routers/skills.py` 提取 `_validate_skill_frontmatter`
2. 创建 `src/utils/file_conversion.py`,从 `gateway/routers/uploads.py` 提取文件转换逻辑
3. 更新 `client.py``gateway/routers/skills.py``gateway/routers/uploads.py` 的 import
4. 运行全部测试确认无回归
### PR 2Rename + 物理拆分High Risk原子操作
1. 创建 `packages/harness/` 目录,创建 `pyproject.toml`
2. `git mv` 将 harness 相关模块从 `src/` 移入 `packages/harness/deerflow/`
3. `git mv` 将 app 相关模块从 `src/` 移入 `app/`
4. 全局替换 import
- harness 模块:`src.*``deerflow.*`(所有 `.py` 文件、`langgraph.json`、测试、文档)
- app 模块:`src.gateway.*``app.gateway.*``src.channels.*``app.channels.*`
5. 更新 workspace root `pyproject.toml`
6. 更新 `langgraph.json``Makefile``Dockerfile`
7. `uv sync` + 全部测试 + 手动验证服务启动
### PR 3边界检查 + 文档Low Risk
1. 添加 lint 规则:检查 harness 不 import app 模块
2. 更新 `CLAUDE.md``README.md`
## 7. 风险与缓解
| 风险 | 影响 | 缓解措施 |
|------|------|----------|
| 全局 rename 误伤 | 字符串中的 `src` 被错误替换 | 正则精确匹配 `\bsrc\.`review diff |
| LangGraph Server 找不到模块 | 服务启动失败 | `langgraph.json``dependencies` 指向正确的 harness 包路径 |
| App 的 `PYTHONPATH` 缺失 | Gateway/Channel 启动 import 报错 | Makefile/Docker 统一设置 `PYTHONPATH=.` |
| `config.yaml` 中的 `use` 字段引用旧路径 | 运行时模块解析失败 | `config.yaml` 中的 `use` 字段同步更新为 `deerflow.*` |
| 测试中 `sys.path` 混乱 | 测试失败 | 用 editable install`uv sync`)确保 deerflow 可导入,`conftest.py` 中添加 `app/``sys.path` |
## 8. 未来演进
- **独立发布**harness 可以发布到内部 PyPI让其他项目直接 `pip install deerflow-harness`
- **插件化 App**:不同的 appweb、CLI、bot可以各自独立都依赖同一个 harness
- **更细粒度拆分**:如果 harness 内部模块继续增长,可以进一步拆分(如 `deerflow-sandbox``deerflow-mcp`

View File

@ -45,6 +45,41 @@ Example:
}
```
## Custom Tool Interceptors
You can register custom interceptors that run before every MCP tool call. This is useful for injecting per-request headers (e.g., user auth tokens from the LangGraph execution context), logging, or metrics.
Declare interceptors in `extensions_config.json` using the `mcpInterceptors` field:
```json
{
"mcpInterceptors": [
"my_package.mcp.auth:build_auth_interceptor"
],
"mcpServers": { ... }
}
```
Each entry is a Python import path in `module:variable` format (resolved via `resolve_variable`). The variable must be a **no-arg builder function** that returns an async interceptor compatible with `MultiServerMCPClient`s `tool_interceptors` interface, or `None` to skip.
Example interceptor that injects auth headers from LangGraph metadata:
```python
def build_auth_interceptor():
async def interceptor(request, handler):
from langgraph.config import get_config
metadata = get_config().get("metadata", {})
headers = dict(request.headers or {})
if token := metadata.get("auth_token"):
headers["X-Auth-Token"] = token
return await handler(request.override(headers=headers))
return interceptor
```
- A single string value is accepted and normalized to a one-element list.
- Invalid paths or builder failures are logged as warnings without blocking other interceptors.
- The builder return value must be `callable`; non-callable values are skipped with a warning.
## How It Works
MCP servers expose tools that are automatically discovered and integrated into DeerFlows agent system at runtime. Once enabled, these tools become available to agents without additional code changes.

View File

@ -124,7 +124,7 @@ title:
# checkpointer.py
from langgraph.checkpoint.sqlite import SqliteSaver
checkpointer = SqliteSaver.from_conn_string("checkpoints.db")
checkpointer = SqliteSaver.from_conn_string("deerflow.db")
```
```json

View File

@ -41,6 +41,13 @@ summarization:
# Custom summary prompt (optional)
summary_prompt: null
# Tool names treated as skill file reads for skill rescue
skill_file_read_tool_names:
- read_file
- read
- view
- cat
```
### Configuration Options
@ -125,6 +132,26 @@ keep:
- **Default**: `null` (uses LangChain's default prompt)
- **Description**: Custom prompt template for generating summaries. The prompt should guide the model to extract the most important context.
#### `preserve_recent_skill_count`
- **Type**: Integer (≥ 0)
- **Default**: `5`
- **Description**: Number of most-recently-loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`, e.g. `/mnt/skills/...`) that are rescued from summarization. Prevents the agent from losing skill instructions after compression. Set to `0` to disable skill rescue entirely.
#### `preserve_recent_skill_tokens`
- **Type**: Integer (≥ 0)
- **Default**: `25000`
- **Description**: Total token budget reserved for rescued skill reads. Once this budget is exhausted, older skill bundles are allowed to be summarized.
#### `preserve_recent_skill_tokens_per_skill`
- **Type**: Integer (≥ 0)
- **Default**: `5000`
- **Description**: Per-skill token cap. Any individual skill read whose tool result exceeds this size is not rescued (it falls through to the summarizer like ordinary content).
#### `skill_file_read_tool_names`
- **Type**: List of strings
- **Default**: `["read_file", "read", "view", "cat"]`
- **Description**: Tool names treated as skill file reads during summarization rescue. A tool call is only eligible for skill rescue when its name appears in this list and its target path is under `skills.container_path`.
**Default Prompt Behavior:**
The default LangChain prompt instructs the model to:
- Extract highest quality/most relevant context
@ -147,6 +174,7 @@ The default LangChain prompt instructs the model to:
- A single summary message is added
- Recent messages are preserved
6. **AI/Tool Pair Protection**: The system ensures AI messages and their corresponding tool messages stay together
7. **Skill Rescue**: Before the summary is generated, the most recently loaded skill files (tool results whose tool name is in `skill_file_read_tool_names` and whose target path is under `skills.container_path`) are lifted out of the summarization set and prepended to the preserved tail. Selection walks newest-first under three budgets: `preserve_recent_skill_count`, `preserve_recent_skill_tokens`, and `preserve_recent_skill_tokens_per_skill`. The triggering AIMessage and all of its paired ToolMessages move together so tool_call ↔ tool_result pairing stays intact.
### Token Counting

View File

@ -8,7 +8,10 @@
"graphs": {
"lead_agent": "deerflow.agents:make_lead_agent"
},
"auth": {
"path": "./app/gateway/langgraph_auth.py:auth"
},
"checkpointer": {
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
"path": "./packages/harness/deerflow/runtime/checkpointer/async_provider.py:make_checkpointer"
}
}

View File

@ -1,4 +1,3 @@
from .checkpointer import get_checkpointer, make_checkpointer, reset_checkpointer
from .factory import create_deerflow_agent
from .features import Next, Prev, RuntimeFeatures
from .lead_agent import make_lead_agent
@ -18,7 +17,4 @@ __all__ = [
"make_lead_agent",
"SandboxState",
"ThreadState",
"get_checkpointer",
"reset_checkpointer",
"make_checkpointer",
]

View File

@ -254,9 +254,11 @@ def _assemble_from_features(
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
chain.append(ViewImageMiddleware())
from deerflow.tools.builtins import view_image_tool
extra_tools.append(view_image_tool)
if feat.sandbox is not False:
from deerflow.tools.builtins import view_image_tool
extra_tools.append(view_image_tool)
# --- [11] Subagent ---
if feat.subagent is not False:

View File

@ -18,7 +18,7 @@ from deerflow.agents.middlewares.tool_error_handling_middleware import build_lea
from deerflow.agents.middlewares.view_image_middleware import ViewImageMiddleware
from deerflow.agents.thread_state import ThreadState
from deerflow.config.agents_config import load_agent_config, validate_agent_name
from deerflow.config.app_config import get_app_config
from deerflow.config.app_config import AppConfig, get_app_config
from deerflow.config.memory_config import get_memory_config
from deerflow.config.summarization_config import get_summarization_config
from deerflow.models import create_chat_model
@ -26,9 +26,18 @@ from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
def _resolve_model_name(requested_model_name: str | None = None) -> str:
def _get_runtime_config(config: RunnableConfig) -> dict:
"""Merge legacy configurable options with LangGraph runtime context."""
cfg = dict(config.get("configurable", {}) or {})
context = config.get("context", {}) or {}
if isinstance(context, dict):
cfg.update(context)
return cfg
def _resolve_model_name(requested_model_name: str | None = None, *, app_config: AppConfig | None = None) -> str:
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
app_config = get_app_config()
app_config = app_config or get_app_config()
default_model_name = app_config.models[0].name if app_config.models else None
if default_model_name is None:
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
@ -41,7 +50,7 @@ def _resolve_model_name(requested_model_name: str | None = None) -> str:
return default_model_name
def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None:
def _create_summarization_middleware(*, app_config: AppConfig | None = None) -> DeerFlowSummarizationMiddleware | None:
"""Create and configure the summarization middleware from config."""
config = get_summarization_config()
@ -59,13 +68,15 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
# Prepare keep parameter
keep = config.keep.to_tuple()
# Prepare model parameter
# Prepare model parameter.
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
model = create_chat_model(name=config.model_name, thinking_enabled=False, app_config=app_config)
else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False)
model = create_chat_model(thinking_enabled=False, app_config=app_config)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
kwargs = {
@ -84,7 +95,25 @@ def _create_summarization_middleware() -> DeerFlowSummarizationMiddleware | None
if get_memory_config().enabled:
hooks.append(memory_flush_hook)
return DeerFlowSummarizationMiddleware(**kwargs, before_summarization=hooks)
# The logic below relies on two assumptions holding true: this factory is
# the sole entry point for DeerFlowSummarizationMiddleware, and the runtime
# config is not expected to change after startup.
try:
resolved_app_config = app_config or get_app_config()
skills_container_path = resolved_app_config.skills.container_path or "/mnt/skills"
except Exception:
logger.exception("Failed to resolve skills container path; falling back to default")
skills_container_path = "/mnt/skills"
return DeerFlowSummarizationMiddleware(
**kwargs,
skills_container_path=skills_container_path,
skill_file_read_tool_names=config.skill_file_read_tool_names,
before_summarization=hooks,
preserve_recent_skill_count=config.preserve_recent_skill_count,
preserve_recent_skill_tokens=config.preserve_recent_skill_tokens,
preserve_recent_skill_tokens_per_skill=config.preserve_recent_skill_tokens_per_skill,
)
def _create_todo_list_middleware(is_plan_mode: bool) -> TodoMiddleware | None:
@ -212,7 +241,14 @@ Being proactive with task management demonstrates thoroughness and ensures all r
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
# ToolErrorHandlingMiddleware should be before ClarificationMiddleware to convert tool exceptions to ToolMessages
# ClarificationMiddleware should be last to intercept clarification requests after model calls
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None, custom_middlewares: list[AgentMiddleware] | None = None):
def _build_middlewares(
config: RunnableConfig,
model_name: str | None,
agent_name: str | None = None,
custom_middlewares: list[AgentMiddleware] | None = None,
*,
app_config: AppConfig | None = None,
):
"""Build middleware chain based on runtime configuration.
Args:
@ -223,21 +259,23 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
Returns:
List of middleware instances.
"""
middlewares = build_lead_runtime_middlewares(lazy_init=True)
resolved_app_config = app_config or get_app_config()
middlewares = build_lead_runtime_middlewares(app_config=resolved_app_config, lazy_init=True)
# Add summarization middleware if enabled
summarization_middleware = _create_summarization_middleware()
summarization_middleware = _create_summarization_middleware(app_config=resolved_app_config)
if summarization_middleware is not None:
middlewares.append(summarization_middleware)
# Add TodoList middleware if plan mode is enabled
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
cfg = _get_runtime_config(config)
is_plan_mode = cfg.get("is_plan_mode", False)
todo_list_middleware = _create_todo_list_middleware(is_plan_mode)
if todo_list_middleware is not None:
middlewares.append(todo_list_middleware)
# Add TokenUsageMiddleware when token_usage tracking is enabled
if get_app_config().token_usage.enabled:
if resolved_app_config.token_usage.enabled:
middlewares.append(TokenUsageMiddleware())
# Add TitleMiddleware
@ -248,21 +286,20 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
# Add ViewImageMiddleware only if the current model supports vision.
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
app_config = get_app_config()
model_config = app_config.get_model_config(model_name) if model_name else None
model_config = resolved_app_config.get_model_config(model_name) if model_name else None
if model_config is not None and model_config.supports_vision:
middlewares.append(ViewImageMiddleware())
# Add DeferredToolFilterMiddleware to hide deferred tool schemas from model binding
if app_config.tool_search.enabled:
if resolved_app_config.tool_search.enabled:
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
middlewares.append(DeferredToolFilterMiddleware())
# Add SubagentLimitMiddleware to truncate excess parallel task calls
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
subagent_enabled = cfg.get("subagent_enabled", False)
if subagent_enabled:
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
max_concurrent_subagents = cfg.get("max_concurrent_subagents", 3)
middlewares.append(SubagentLimitMiddleware(max_concurrent=max_concurrent_subagents))
# LoopDetectionMiddleware — detect and break repetitive tool call loops
@ -278,11 +315,17 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_nam
def make_lead_agent(config: RunnableConfig):
"""LangGraph graph factory; keep the signature compatible with LangGraph Server."""
return _make_lead_agent(config, app_config=get_app_config())
def _make_lead_agent(config: RunnableConfig, *, app_config: AppConfig):
# Lazy import to avoid circular dependency
from deerflow.tools import get_available_tools
from deerflow.tools.builtins import setup_agent
cfg = config.get("configurable", {})
cfg = _get_runtime_config(config)
resolved_app_config = app_config
thinking_enabled = cfg.get("thinking_enabled", True)
reasoning_effort = cfg.get("reasoning_effort", None)
@ -298,10 +341,9 @@ def make_lead_agent(config: RunnableConfig):
agent_model_name = agent_config.model if agent_config and agent_config.model else None
# Final model name resolution: request → agent config → global default, with fallback for unknown names
model_name = _resolve_model_name(requested_model_name or agent_model_name)
model_name = _resolve_model_name(requested_model_name or agent_model_name, app_config=resolved_app_config)
app_config = get_app_config()
model_config = app_config.get_model_config(model_name)
model_config = resolved_app_config.get_model_config(model_name)
if model_config is None:
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
@ -333,26 +375,41 @@ def make_lead_agent(config: RunnableConfig):
"is_plan_mode": is_plan_mode,
"subagent_enabled": subagent_enabled,
"tool_groups": agent_config.tool_groups if agent_config else None,
"available_skills": ["bootstrap"] if is_bootstrap else (agent_config.skills if agent_config and agent_config.skills is not None else None),
}
)
if is_bootstrap:
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name),
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"])),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, app_config=resolved_app_config),
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled, app_config=resolved_app_config) + [setup_agent],
middleware=_build_middlewares(config, model_name=model_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
available_skills=set(["bootstrap"]),
app_config=resolved_app_config,
),
state_schema=ThreadState,
)
# Default lead agent (unchanged behavior)
return create_agent(
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort, app_config=resolved_app_config),
tools=get_available_tools(
model_name=model_name,
groups=agent_config.tool_groups if agent_config else None,
subagent_enabled=subagent_enabled,
app_config=resolved_app_config,
),
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name, app_config=resolved_app_config),
system_prompt=apply_prompt_template(
subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name, available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None
subagent_enabled=subagent_enabled,
max_concurrent_subagents=max_concurrent_subagents,
agent_name=agent_name,
available_skills=set(agent_config.skills) if agent_config and agent_config.skills is not None else None,
app_config=resolved_app_config,
),
state_schema=ThreadState,
)

View File

@ -1,14 +1,20 @@
from __future__ import annotations
import asyncio
import logging
import threading
from datetime import datetime
from functools import lru_cache
from typing import TYPE_CHECKING
from deerflow.config.agents_config import load_agent_soul
from deerflow.skills import load_skills
from deerflow.skills.types import Skill
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.skills.types import Skill, SkillCategory
from deerflow.subagents import get_available_subagent_names
if TYPE_CHECKING:
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_ENABLED_SKILLS_REFRESH_WAIT_TIMEOUT_SECONDS = 5.0
@ -20,7 +26,7 @@ _enabled_skills_refresh_event = threading.Event()
def _load_enabled_skills_sync() -> list[Skill]:
return list(load_skills(enabled_only=True))
return list(get_or_new_skill_storage().load_skills(enabled_only=True))
def _start_enabled_skills_refresh_thread() -> None:
@ -111,8 +117,21 @@ def _get_enabled_skills():
return []
def _skill_mutability_label(category: str) -> str:
return "[custom, editable]" if category == "custom" else "[built-in]"
def _get_enabled_skills_for_config(app_config: AppConfig | None = None) -> list[Skill]:
"""Return enabled skills using the caller's config source.
When a concrete ``app_config`` is supplied, bypass the global enabled-skills
cache so the skill list and skill paths are resolved from the same config
object. This keeps request-scoped config injection consistent even while the
release branch still supports global fallback paths.
"""
if app_config is None:
return _get_enabled_skills()
return list(get_or_new_skill_storage(app_config=app_config).load_skills(enabled_only=True))
def _skill_mutability_label(category: SkillCategory | str) -> str:
return "[custom, editable]" if category == SkillCategory.CUSTOM else "[built-in]"
def clear_skills_system_prompt_cache() -> None:
@ -123,31 +142,6 @@ async def refresh_skills_system_prompt_cache_async() -> None:
await asyncio.to_thread(_invalidate_enabled_skills_cache().wait)
def _reset_skills_system_prompt_cache_state() -> None:
global _enabled_skills_cache, _enabled_skills_refresh_active, _enabled_skills_refresh_version
_get_cached_skills_prompt_section.cache_clear()
with _enabled_skills_lock:
_enabled_skills_cache = None
_enabled_skills_refresh_active = False
_enabled_skills_refresh_version = 0
_enabled_skills_refresh_event.clear()
def _refresh_enabled_skills_cache() -> None:
"""Backward-compatible test helper for direct synchronous reload."""
try:
skills = _load_enabled_skills_sync()
except Exception:
logger.exception("Failed to load enabled skills for prompt injection")
skills = []
with _enabled_skills_lock:
_enabled_skills_cache = skills
_enabled_skills_refresh_active = False
_enabled_skills_refresh_event.set()
def _build_skill_evolution_section(skill_evolution_enabled: bool) -> str:
if not skill_evolution_enabled:
return ""
@ -164,6 +158,36 @@ Skip simple one-off tasks.
"""
def _build_available_subagents_description(available_names: list[str], bash_available: bool) -> str:
"""Dynamically build subagent type descriptions from registry.
Mirrors Codex's pattern where agent_type_description is dynamically generated
from all registered roles, so the LLM knows about every available type.
"""
# Built-in descriptions (kept for backward compatibility with existing prompt quality)
builtin_descriptions = {
"general-purpose": "For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.",
"bash": (
"For command execution (git, build, test, deploy operations)" if bash_available else "Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
),
}
# Lazy import moved outside loop to avoid repeated import overhead
from deerflow.subagents.registry import get_subagent_config
lines = []
for name in available_names:
if name in builtin_descriptions:
lines.append(f"- **{name}**: {builtin_descriptions[name]}")
else:
config = get_subagent_config(name)
if config is not None:
desc = config.description.split("\n")[0].strip() # First line only for brevity
lines.append(f"- **{name}**: {desc}")
return "\n".join(lines)
def _build_subagent_section(max_concurrent: int) -> str:
"""Build the subagent system prompt section with dynamic concurrency limit.
@ -174,13 +198,12 @@ def _build_subagent_section(max_concurrent: int) -> str:
Formatted subagent section string.
"""
n = max_concurrent
bash_available = "bash" in get_available_subagent_names()
available_subagents = (
"- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n- **bash**: For command execution (git, build, test, deploy operations)"
if bash_available
else "- **general-purpose**: For ANY non-trivial task - web research, code exploration, file operations, analysis, etc.\n"
"- **bash**: Not available in the current sandbox configuration. Use direct file/web tools or switch to AioSandboxProvider for isolated shell access."
)
available_names = get_available_subagent_names()
bash_available = "bash" in available_names
# Dynamically build subagent type descriptions from registry (aligned with Codex's
# agent_type_description pattern where all registered roles are listed in the tool spec).
available_subagents = _build_available_subagents_description(available_names, bash_available)
direct_tool_examples = "bash, ls, read_file, web_search, etc." if bash_available else "ls, read_file, web_search, etc."
direct_execution_example = (
'# User asks: "Run the tests"\n# Thinking: Cannot decompose into parallel sub-tasks\n# → Execute directly\n\nbash("npm test") # Direct execution, not task()'
@ -519,12 +542,13 @@ def _get_memory_context(agent_name: str | None = None) -> str:
try:
from deerflow.agents.memory import format_memory_for_injection, get_memory_data
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
config = get_memory_config()
if not config.enabled or not config.injection_enabled:
return ""
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
if not memory_content.strip():
@ -571,14 +595,14 @@ You have access to skills that provide optimized workflows for specific tasks. E
</skill_system>"""
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
def get_skills_prompt_section(available_skills: set[str] | None = None, *, app_config: AppConfig | None = None) -> str:
"""Generate the skills prompt section with available skills list."""
skills = _get_enabled_skills()
skills = _get_enabled_skills_for_config(app_config)
try:
from deerflow.config import get_app_config
config = get_app_config()
config = app_config or get_app_config()
container_base_path = config.skills.container_path
skill_evolution_enabled = config.skill_evolution.enabled
except Exception:
@ -607,7 +631,7 @@ def get_agent_soul(agent_name: str | None) -> str:
return ""
def get_deferred_tools_prompt_section() -> str:
def get_deferred_tools_prompt_section(*, app_config: AppConfig | None = None) -> str:
"""Generate <available-deferred-tools> block for the system prompt.
Lists only deferred tool names so the agent knows what exists
@ -619,7 +643,8 @@ def get_deferred_tools_prompt_section() -> str:
try:
from deerflow.config import get_app_config
if not get_app_config().tool_search.enabled:
config = app_config or get_app_config()
if not config.tool_search.enabled:
return ""
except Exception:
return ""
@ -652,12 +677,13 @@ def _build_acp_section() -> str:
)
def _build_custom_mounts_section() -> str:
def _build_custom_mounts_section(*, app_config: AppConfig | None = None) -> str:
"""Build a prompt section for explicitly configured sandbox mounts."""
try:
from deerflow.config import get_app_config
mounts = get_app_config().sandbox.mounts or []
config = app_config or get_app_config()
mounts = config.sandbox.mounts or []
except Exception:
logger.exception("Failed to load configured sandbox mounts for the lead-agent prompt")
return ""
@ -674,7 +700,14 @@ def _build_custom_mounts_section() -> str:
return f"\n**Custom Mounted Directories:**\n{mounts_list}\n- If the user needs files outside `/mnt/user-data`, use these absolute container paths directly when they match the requested directory"
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
def apply_prompt_template(
subagent_enabled: bool = False,
max_concurrent_subagents: int = 3,
*,
agent_name: str | None = None,
available_skills: set[str] | None = None,
app_config: AppConfig | None = None,
) -> str:
# Get memory context
memory_context = _get_memory_context(agent_name)
@ -701,14 +734,14 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
)
# Get skills section
skills_section = get_skills_prompt_section(available_skills)
skills_section = get_skills_prompt_section(available_skills, app_config=app_config)
# Get deferred tools section (tool_search)
deferred_tools_section = get_deferred_tools_prompt_section()
deferred_tools_section = get_deferred_tools_prompt_section(app_config=app_config)
# Build ACP agent section only if ACP agents are configured
acp_section = _build_acp_section()
custom_mounts_section = _build_custom_mounts_section()
custom_mounts_section = _build_custom_mounts_section(app_config=app_config)
acp_and_mounts_section = "\n".join(section for section in (acp_section, custom_mounts_section) if section)
# Format the prompt with dynamic skills and memory

View File

@ -20,6 +20,7 @@ class ConversationContext:
messages: list[Any]
timestamp: datetime = field(default_factory=lambda: datetime.now(UTC))
agent_name: str | None = None
user_id: str | None = None
correction_detected: bool = False
reinforcement_detected: bool = False
@ -44,6 +45,7 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
user_id: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
@ -53,6 +55,9 @@ class MemoryUpdateQueue:
thread_id: The thread ID.
messages: The conversation messages.
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
user_id: The user ID captured at enqueue time. Stored in ConversationContext so it
survives the threading.Timer boundary (ContextVar does not propagate across
raw threads).
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
"""
@ -65,6 +70,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@ -77,6 +83,7 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None = None,
user_id: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
) -> None:
@ -90,6 +97,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@ -103,6 +111,7 @@ class MemoryUpdateQueue:
thread_id: str,
messages: list[Any],
agent_name: str | None,
user_id: str | None,
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
@ -116,6 +125,7 @@ class MemoryUpdateQueue:
thread_id=thread_id,
messages=messages,
agent_name=agent_name,
user_id=user_id,
correction_detected=merged_correction_detected,
reinforcement_detected=merged_reinforcement_detected,
)
@ -176,6 +186,7 @@ class MemoryUpdateQueue:
agent_name=context.agent_name,
correction_detected=context.correction_detected,
reinforcement_detected=context.reinforcement_detected,
user_id=context.user_id,
)
if success:
logger.info("Memory updated successfully for thread %s", context.thread_id)

View File

@ -44,17 +44,17 @@ class MemoryStorage(abc.ABC):
"""Abstract base class for memory storage providers."""
@abc.abstractmethod
def load(self, agent_name: str | None = None) -> dict[str, Any]:
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data for the given agent."""
pass
@abc.abstractmethod
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Force reload memory data for the given agent."""
pass
@abc.abstractmethod
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save memory data for the given agent."""
pass
@ -64,9 +64,9 @@ class FileMemoryStorage(MemoryStorage):
def __init__(self):
"""Initialize the file memory storage."""
# Per-agent memory cache: keyed by agent_name (None = global)
# Per-user/agent memory cache: keyed by (user_id, agent_name) tuple (None = global)
# Value: (memory_data, file_mtime)
self._memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
self._memory_cache: dict[tuple[str | None, str | None], tuple[dict[str, Any], float | None]] = {}
# Guards all reads and writes to _memory_cache across concurrent callers.
self._cache_lock = threading.Lock()
@ -81,21 +81,29 @@ class FileMemoryStorage(MemoryStorage):
if not AGENT_NAME_PATTERN.match(agent_name):
raise ValueError(f"Invalid agent name {agent_name!r}: names must match {AGENT_NAME_PATTERN.pattern}")
def _get_memory_file_path(self, agent_name: str | None = None) -> Path:
def _get_memory_file_path(self, agent_name: str | None = None, *, user_id: str | None = None) -> Path:
"""Get the path to the memory file."""
if user_id is not None:
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().user_agent_memory_file(user_id, agent_name)
config = get_memory_config()
if config.storage_path and Path(config.storage_path).is_absolute():
return Path(config.storage_path)
return get_paths().user_memory_file(user_id)
# Legacy: no user_id
if agent_name is not None:
self._validate_agent_name(agent_name)
return get_paths().agent_memory_file(agent_name)
config = get_memory_config()
if config.storage_path:
p = Path(config.storage_path)
return p if p.is_absolute() else get_paths().base_dir / p
return get_paths().memory_file
def _load_memory_from_file(self, agent_name: str | None = None) -> dict[str, Any]:
def _load_memory_from_file(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data from file."""
file_path = self._get_memory_file_path(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
if not file_path.exists():
return create_empty_memory()
@ -108,9 +116,14 @@ class FileMemoryStorage(MemoryStorage):
logger.warning("Failed to load memory file: %s", e)
return create_empty_memory()
def load(self, agent_name: str | None = None) -> dict[str, Any]:
@staticmethod
def _cache_key(agent_name: str | None = None, *, user_id: str | None = None) -> tuple[str | None, str | None]:
return (user_id, agent_name)
def load(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Load memory data (cached with file modification time check)."""
file_path = self._get_memory_file_path(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try:
current_mtime = file_path.stat().st_mtime if file_path.exists() else None
@ -118,21 +131,22 @@ class FileMemoryStorage(MemoryStorage):
current_mtime = None
with self._cache_lock:
cached = self._memory_cache.get(agent_name)
cached = self._memory_cache.get(cache_key)
if cached is not None and cached[1] == current_mtime:
return cached[0]
memory_data = self._load_memory_from_file(agent_name)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, current_mtime)
self._memory_cache[cache_key] = (memory_data, current_mtime)
return memory_data
def reload(self, agent_name: str | None = None) -> dict[str, Any]:
def reload(self, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data from file, forcing cache invalidation."""
file_path = self._get_memory_file_path(agent_name)
memory_data = self._load_memory_from_file(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
memory_data = self._load_memory_from_file(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try:
mtime = file_path.stat().st_mtime if file_path.exists() else None
@ -140,12 +154,13 @@ class FileMemoryStorage(MemoryStorage):
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
return memory_data
def save(self, memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
def save(self, memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Save memory data to file and update cache."""
file_path = self._get_memory_file_path(agent_name)
file_path = self._get_memory_file_path(agent_name, user_id=user_id)
cache_key = self._cache_key(agent_name, user_id=user_id)
try:
file_path.parent.mkdir(parents=True, exist_ok=True)
@ -166,7 +181,7 @@ class FileMemoryStorage(MemoryStorage):
mtime = None
with self._cache_lock:
self._memory_cache[agent_name] = (memory_data, mtime)
self._memory_cache[cache_key] = (memory_data, mtime)
logger.info("Memory saved to %s", file_path)
return True
except OSError as e:

View File

@ -9,7 +9,6 @@ import logging
import math
import re
import uuid
from collections.abc import Awaitable
from typing import Any
from deerflow.agents.memory.prompt import (
@ -26,6 +25,12 @@ from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
# Thread pool for offloading sync memory updates when called from an async
# context. Unlike the previous asyncio.run() approach, this runs *sync*
# model.invoke() calls — no event loop is created, so the langchain async
# httpx client pool (globally cached via @lru_cache) is never touched and
# cross-loop connection reuse is impossible.
_SYNC_MEMORY_UPDATER_EXECUTOR = concurrent.futures.ThreadPoolExecutor(
max_workers=4,
thread_name_prefix="memory-updater-sync",
@ -38,27 +43,28 @@ def _create_empty_memory() -> dict[str, Any]:
return create_empty_memory()
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> bool:
"""Backward-compatible wrapper around the configured memory storage save path."""
return get_memory_storage().save(memory_data, agent_name)
return get_memory_storage().save(memory_data, agent_name, user_id=user_id)
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def get_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Get the current memory data via storage provider."""
return get_memory_storage().load(agent_name)
return get_memory_storage().load(agent_name, user_id=user_id)
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def reload_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Reload memory data via storage provider."""
return get_memory_storage().reload(agent_name)
return get_memory_storage().reload(agent_name, user_id=user_id)
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None) -> dict[str, Any]:
def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Persist imported memory data via storage provider.
Args:
memory_data: Full memory payload to persist.
agent_name: If provided, imports into per-agent memory.
user_id: If provided, scopes memory to a specific user.
Returns:
The saved memory data after storage normalization.
@ -67,15 +73,15 @@ def import_memory_data(memory_data: dict[str, Any], agent_name: str | None = Non
OSError: If persisting the imported memory fails.
"""
storage = get_memory_storage()
if not storage.save(memory_data, agent_name):
if not storage.save(memory_data, agent_name, user_id=user_id):
raise OSError("Failed to save imported memory data")
return storage.load(agent_name)
return storage.load(agent_name, user_id=user_id)
def clear_memory_data(agent_name: str | None = None) -> dict[str, Any]:
def clear_memory_data(agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Clear all stored memory data and persist an empty structure."""
cleared_memory = create_empty_memory()
if not _save_memory_to_file(cleared_memory, agent_name):
if not _save_memory_to_file(cleared_memory, agent_name, user_id=user_id):
raise OSError("Failed to save cleared memory data")
return cleared_memory
@ -92,6 +98,8 @@ def create_memory_fact(
category: str = "context",
confidence: float = 0.5,
agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]:
"""Create a new fact and persist the updated memory data."""
normalized_content = content.strip()
@ -101,7 +109,7 @@ def create_memory_fact(
normalized_category = category.strip() or "context"
validated_confidence = _validate_confidence(confidence)
now = utc_now_iso_z()
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(agent_name, user_id=user_id)
updated_memory = dict(memory_data)
facts = list(memory_data.get("facts", []))
facts.append(
@ -116,15 +124,15 @@ def create_memory_fact(
)
updated_memory["facts"] = facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError("Failed to save memory data after creating fact")
return updated_memory
def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str, Any]:
def delete_memory_fact(fact_id: str, agent_name: str | None = None, *, user_id: str | None = None) -> dict[str, Any]:
"""Delete a fact by its id and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(agent_name, user_id=user_id)
facts = memory_data.get("facts", [])
updated_facts = [fact for fact in facts if fact.get("id") != fact_id]
if len(updated_facts) == len(facts):
@ -133,7 +141,7 @@ def delete_memory_fact(fact_id: str, agent_name: str | None = None) -> dict[str,
updated_memory = dict(memory_data)
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after deleting fact '{fact_id}'")
return updated_memory
@ -145,9 +153,11 @@ def update_memory_fact(
category: str | None = None,
confidence: float | None = None,
agent_name: str | None = None,
*,
user_id: str | None = None,
) -> dict[str, Any]:
"""Update an existing fact and persist the updated memory data."""
memory_data = get_memory_data(agent_name)
memory_data = get_memory_data(agent_name, user_id=user_id)
updated_memory = dict(memory_data)
updated_facts: list[dict[str, Any]] = []
found = False
@ -174,7 +184,7 @@ def update_memory_fact(
updated_memory["facts"] = updated_facts
if not _save_memory_to_file(updated_memory, agent_name):
if not _save_memory_to_file(updated_memory, agent_name, user_id=user_id):
raise OSError(f"Failed to save memory data after updating fact '{fact_id}'")
return updated_memory
@ -217,39 +227,6 @@ def _extract_text(content: Any) -> str:
return str(content)
def _run_async_update_sync(coro: Awaitable[bool]) -> bool:
"""Run an async memory update from sync code, including nested-loop contexts."""
handed_off = False
try:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(asyncio.run, coro)
handed_off = True
return future.result()
handed_off = True
return asyncio.run(coro)
except Exception:
if not handed_off:
close = getattr(coro, "close", None)
if callable(close):
try:
close()
except Exception:
logger.debug(
"Failed to close un-awaited memory update coroutine",
exc_info=True,
)
logger.exception("Failed to run async memory update from sync context")
return False
# Matches sentences that describe a file-upload *event* rather than general
# file-related work. Deliberately narrow to avoid removing legitimate facts
# such as "User works with CSV files" or "prefers PDF export".
@ -344,13 +321,14 @@ class MemoryUpdater:
agent_name: str | None,
correction_detected: bool,
reinforcement_detected: bool,
user_id: str | None = None,
) -> tuple[dict[str, Any], str] | None:
"""Load memory and build the update prompt for a conversation."""
config = get_memory_config()
if not config.enabled or not messages:
return None
current_memory = get_memory_data(agent_name)
current_memory = get_memory_data(agent_name, user_id=user_id)
conversation_text = format_conversation_for_update(messages)
if not conversation_text.strip():
return None
@ -372,6 +350,7 @@ class MemoryUpdater:
response_content: Any,
thread_id: str | None,
agent_name: str | None,
user_id: str | None = None,
) -> bool:
"""Parse the model response, apply updates, and persist memory."""
response_text = _extract_text(response_content).strip()
@ -385,7 +364,7 @@ class MemoryUpdater:
# cannot corrupt the still-cached original object reference.
updated_memory = self._apply_updates(copy.deepcopy(current_memory), update_data, thread_id)
updated_memory = _strip_upload_mentions_from_memory(updated_memory)
return get_memory_storage().save(updated_memory, agent_name)
return get_memory_storage().save(updated_memory, agent_name, user_id=user_id)
async def aupdate_memory(
self,
@ -394,28 +373,63 @@ class MemoryUpdater:
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Update memory asynchronously based on conversation messages."""
"""Update memory asynchronously by delegating to the sync path.
Uses ``asyncio.to_thread`` to run the *sync* ``model.invoke()`` path
in a worker thread so no second event loop is created and the
langchain async httpx client pool (shared with the lead agent) is
never touched. This eliminates the cross-loop connection-reuse bug
described in issue #2615.
"""
return await asyncio.to_thread(
self._do_update_memory_sync,
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
def _do_update_memory_sync(
self,
messages: list[Any],
thread_id: str | None = None,
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Pure-sync memory update using ``model.invoke()``.
Uses the *sync* LLM call path so no event loop is created. This
guarantees that the langchain provider's globally cached async
httpx ``AsyncClient`` / connection pool (the one shared with the
lead agent) is never touched no cross-loop connection reuse is
possible.
"""
try:
prepared = await asyncio.to_thread(
self._prepare_update_prompt,
prepared = self._prepare_update_prompt(
messages=messages,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
if prepared is None:
return False
current_memory, prompt = prepared
model = self._get_model()
response = await model.ainvoke(prompt)
return await asyncio.to_thread(
self._finalize_update,
response = model.invoke(prompt, config={"run_name": "memory_agent"})
return self._finalize_update(
current_memory=current_memory,
response_content=response.content,
thread_id=thread_id,
agent_name=agent_name,
user_id=user_id,
)
except json.JSONDecodeError as e:
logger.warning("Failed to parse LLM response for memory update: %s", e)
@ -431,8 +445,18 @@ class MemoryUpdater:
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Synchronously update memory via the async updater path.
"""Synchronously update memory using the sync LLM path.
Uses ``model.invoke()`` (sync HTTP) which operates on a completely
separate connection pool from the async ``AsyncClient`` shared by
the lead agent. This eliminates the cross-loop connection-reuse
bug described in issue #2615.
When called from within a running event loop (e.g. from a LangGraph
node), the blocking sync call is offloaded to a thread pool so the
caller's loop is not blocked.
Args:
messages: List of conversation messages.
@ -440,18 +464,39 @@ class MemoryUpdater:
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns:
True if update was successful, False otherwise.
"""
return _run_async_update_sync(
self.aupdate_memory(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop is not None and loop.is_running():
try:
future = _SYNC_MEMORY_UPDATER_EXECUTOR.submit(
self._do_update_memory_sync,
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
return future.result()
except Exception:
logger.exception("Failed to offload memory update to executor")
return False
return self._do_update_memory_sync(
messages=messages,
thread_id=thread_id,
agent_name=agent_name,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
user_id=user_id,
)
def _apply_updates(
@ -547,6 +592,7 @@ def update_memory_from_conversation(
agent_name: str | None = None,
correction_detected: bool = False,
reinforcement_detected: bool = False,
user_id: str | None = None,
) -> bool:
"""Convenience function to update memory from a conversation.
@ -556,9 +602,10 @@ def update_memory_from_conversation(
agent_name: If provided, updates per-agent memory. If None, updates global memory.
correction_detected: Whether recent turns include an explicit correction signal.
reinforcement_detected: Whether recent turns include a positive reinforcement signal.
user_id: If provided, scopes memory to a specific user.
Returns:
True if successful, False otherwise.
"""
updater = MemoryUpdater()
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected)
return updater.update_memory(messages, thread_id, agent_name, correction_detected, reinforcement_detected, user_id=user_id)

View File

@ -16,6 +16,9 @@ from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
@ -35,7 +38,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
if not registry:
return request
deferred_names = {e.name for e in registry.entries}
deferred_names = registry.deferred_names
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools):
@ -43,6 +46,28 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
return request.override(tools=active_tools)
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return None
tool_name = str(request.tool_call.get("name") or "")
if not tool_name:
return None
if not registry.contains(tool_name):
return None
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
return ToolMessage(
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_model_call(
self,
@ -51,6 +76,17 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
) -> ModelCallResult:
return handler(self._filter_tools(request))
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return handler(request)
@override
async def awrap_model_call(
self,
@ -58,3 +94,14 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._filter_tools(request))
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return await handler(request)

View File

@ -20,7 +20,7 @@ from langchain.agents.middleware.types import (
from langchain_core.messages import AIMessage
from langgraph.errors import GraphBubbleUp
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
@ -70,20 +70,11 @@ class LLMErrorHandlingMiddleware(AgentMiddleware[AgentState]):
retry_base_delay_ms: int = 1000
retry_cap_delay_ms: int = 8000
circuit_failure_threshold: int = 5
circuit_recovery_timeout_sec: int = 60
def __init__(self, **kwargs: Any) -> None:
def __init__(self, *, app_config: AppConfig, **kwargs: Any) -> None:
super().__init__(**kwargs)
# Load Circuit Breaker configs from app config if available, fall back to defaults
try:
app_config = get_app_config()
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
except (FileNotFoundError, RuntimeError):
# Gracefully fall back to class defaults in test environments
pass
self.circuit_failure_threshold = app_config.circuit_breaker.failure_threshold
self.circuit_recovery_timeout_sec = app_config.circuit_breaker.recovery_timeout_sec
# Circuit Breaker state
self._circuit_lock = threading.Lock()

View File

@ -362,7 +362,7 @@ class LoopDetectionMiddleware(AgentMiddleware[AgentState]):
# the conversation; injecting one mid-conversation crashes
# langchain_anthropic's _format_messages(). HumanMessage works
# with all providers. See #1299.
return {"messages": [HumanMessage(content=warning)]}
return {"messages": [HumanMessage(content=warning, name="loop_warning")]}
return None

View File

@ -11,6 +11,7 @@ from langgraph.runtime import Runtime
from deerflow.agents.memory.message_processing import detect_correction, detect_reinforcement, filter_messages_for_memory
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@ -86,11 +87,16 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
# Queue the filtered conversation for memory update
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
# Capture user_id at enqueue time while the request context is still alive.
# threading.Timer fires on a different thread where ContextVar values are not
# propagated, so we must store user_id explicitly in ConversationContext.
user_id = get_effective_user_id()
queue = get_memory_queue()
queue.add(
thread_id=thread_id,
messages=filtered_messages,
agent_name=self._agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)

View File

@ -3,12 +3,13 @@
from __future__ import annotations
import logging
from collections.abc import Collection
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from typing import Any, Protocol, override, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AnyMessage, RemoveMessage
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
@ -58,17 +59,63 @@ def _resolve_agent_name(runtime: Runtime) -> str | None:
return agent_name
def _tool_call_path(tool_call: dict[str, Any]) -> str | None:
"""Best-effort extraction of a file path argument from a read_file-like tool call."""
args = tool_call.get("args") or {}
if not isinstance(args, dict):
return None
for key in ("path", "file_path", "filepath"):
value = args.get(key)
if isinstance(value, str) and value:
return value
return None
def _clone_ai_message(
message: AIMessage,
tool_calls: list[dict[str, Any]],
*,
content: Any | None = None,
) -> AIMessage:
"""Clone an AIMessage while replacing its tool_calls list and optional content."""
update: dict[str, Any] = {"tool_calls": tool_calls}
if content is not None:
update["content"] = content
return message.model_copy(update=update)
@dataclass
class _SkillBundle:
"""Skill-related tool calls and tool results associated with one AIMessage."""
ai_index: int
skill_tool_indices: tuple[int, ...]
skill_tool_call_ids: frozenset[str]
skill_tool_tokens: int
skill_key: str
class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
"""Summarization middleware with pre-compression hook dispatch."""
"""Summarization middleware with pre-compression hook dispatch and skill rescue."""
def __init__(
self,
*args,
skills_container_path: str | None = None,
skill_file_read_tool_names: Collection[str] | None = None,
before_summarization: list[BeforeSummarizationHook] | None = None,
preserve_recent_skill_count: int = 5,
preserve_recent_skill_tokens: int = 25_000,
preserve_recent_skill_tokens_per_skill: int = 5_000,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self._skills_container_path = skills_container_path or "/mnt/skills"
self._skill_file_read_tool_names = frozenset(skill_file_read_tool_names or {"read_file", "read", "view", "cat"})
self._before_summarization_hooks = before_summarization or []
self._preserve_recent_skill_count = max(0, preserve_recent_skill_count)
self._preserve_recent_skill_tokens = max(0, preserve_recent_skill_tokens)
self._preserve_recent_skill_tokens_per_skill = max(0, preserve_recent_skill_tokens_per_skill)
def before_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._maybe_summarize(state, runtime)
@ -88,7 +135,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = self._create_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@ -113,7 +160,7 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
if cutoff_index <= 0:
return None
messages_to_summarize, preserved_messages = self._partition_messages(messages, cutoff_index)
messages_to_summarize, preserved_messages = self._partition_with_skill_rescue(messages, cutoff_index)
self._fire_hooks(messages_to_summarize, preserved_messages, runtime)
summary = await self._acreate_summary(messages_to_summarize)
new_messages = self._build_new_messages(summary)
@ -126,6 +173,162 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
]
}
@override
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
"""Override the base implementation to let the human message with the special name 'summary'.
And this message will be ignored to display in the frontend, but still can be used as context for the model.
"""
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
def _partition_with_skill_rescue(
self,
messages: list[AnyMessage],
cutoff_index: int,
) -> tuple[list[AnyMessage], list[AnyMessage]]:
"""Partition like the parent, then rescue recently-loaded skill bundles."""
to_summarize, preserved = self._partition_messages(messages, cutoff_index)
if self._preserve_recent_skill_count == 0 or self._preserve_recent_skill_tokens == 0 or not to_summarize:
return to_summarize, preserved
try:
bundles = self._find_skill_bundles(to_summarize, self._skills_container_path)
except Exception:
logger.exception("Skill-preserving summarization rescue failed; falling back to default partition")
return to_summarize, preserved
if not bundles:
return to_summarize, preserved
rescue_bundles = self._select_bundles_to_rescue(bundles)
if not rescue_bundles:
return to_summarize, preserved
bundles_by_ai_index = {bundle.ai_index: bundle for bundle in rescue_bundles}
rescue_tool_indices = {idx for bundle in rescue_bundles for idx in bundle.skill_tool_indices}
rescued: list[AnyMessage] = []
remaining: list[AnyMessage] = []
for i, msg in enumerate(to_summarize):
bundle = bundles_by_ai_index.get(i)
if bundle is not None and isinstance(msg, AIMessage):
rescued_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") in bundle.skill_tool_call_ids]
remaining_tool_calls = [tc for tc in msg.tool_calls if tc.get("id") not in bundle.skill_tool_call_ids]
if rescued_tool_calls:
rescued.append(_clone_ai_message(msg, rescued_tool_calls, content=""))
if remaining_tool_calls or msg.content:
remaining.append(_clone_ai_message(msg, remaining_tool_calls))
continue
if i in rescue_tool_indices:
rescued.append(msg)
continue
remaining.append(msg)
return remaining, rescued + preserved
def _find_skill_bundles(
self,
messages: list[AnyMessage],
skills_root: str,
) -> list[_SkillBundle]:
"""Locate AIMessage + paired ToolMessage groups that load skill files."""
bundles: list[_SkillBundle] = []
n = len(messages)
i = 0
while i < n:
msg = messages[i]
if not (isinstance(msg, AIMessage) and msg.tool_calls):
i += 1
continue
tool_calls = list(msg.tool_calls)
skill_paths_by_id: dict[str, str] = {}
for tc in tool_calls:
if self._is_skill_tool_call(tc, skills_root):
tc_id = tc.get("id")
path = _tool_call_path(tc)
if tc_id and path:
skill_paths_by_id[tc_id] = path
if not skill_paths_by_id:
i += 1
continue
skill_tool_tokens = 0
skill_key_parts: list[str] = []
skill_tool_indices: list[int] = []
matched_skill_call_ids: set[str] = set()
j = i + 1
while j < n and isinstance(messages[j], ToolMessage):
j += 1
for k in range(i + 1, j):
tool_msg = messages[k]
if isinstance(tool_msg, ToolMessage) and tool_msg.tool_call_id in skill_paths_by_id:
skill_tool_tokens += self.token_counter([tool_msg])
skill_key_parts.append(skill_paths_by_id[tool_msg.tool_call_id])
skill_tool_indices.append(k)
matched_skill_call_ids.add(tool_msg.tool_call_id)
if not skill_tool_indices:
i = j
continue
bundles.append(
_SkillBundle(
ai_index=i,
skill_tool_indices=tuple(skill_tool_indices),
skill_tool_call_ids=frozenset(matched_skill_call_ids),
skill_tool_tokens=skill_tool_tokens,
skill_key="|".join(sorted(skill_key_parts)),
)
)
i = j
return bundles
def _select_bundles_to_rescue(self, bundles: list[_SkillBundle]) -> list[_SkillBundle]:
"""Pick bundles to keep, walking newest-first under count/token budgets."""
selected: list[_SkillBundle] = []
if not bundles:
return selected
seen_skill_keys: set[str] = set()
total_tokens = 0
kept = 0
for bundle in reversed(bundles):
if kept >= self._preserve_recent_skill_count:
break
if bundle.skill_key in seen_skill_keys:
continue
if bundle.skill_tool_tokens > self._preserve_recent_skill_tokens_per_skill:
continue
if total_tokens + bundle.skill_tool_tokens > self._preserve_recent_skill_tokens:
continue
selected.append(bundle)
total_tokens += bundle.skill_tool_tokens
kept += 1
seen_skill_keys.add(bundle.skill_key)
selected.reverse()
return selected
def _is_skill_tool_call(self, tool_call: dict[str, Any], skills_root: str) -> bool:
"""Return True when ``tool_call`` reads a file under the configured skills root."""
name = tool_call.get("name") or ""
if name not in self._skill_file_read_tool_names:
return False
path = _tool_call_path(tool_call)
if not path:
return False
normalized_root = skills_root.rstrip("/")
return path == normalized_root or path.startswith(normalized_root + "/")
def _fire_hooks(
self,
messages_to_summarize: list[AnyMessage],

View File

@ -1,13 +1,16 @@
import logging
from datetime import UTC, datetime
from typing import NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain_core.messages import HumanMessage
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.agents.thread_state import ThreadDataState
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
@ -46,32 +49,34 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
self._paths = Paths(base_dir) if base_dir else get_paths()
self._lazy_init = lazy_init
def _get_thread_paths(self, thread_id: str) -> dict[str, str]:
def _get_thread_paths(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
"""Get the paths for a thread's data directories.
Args:
thread_id: The thread ID.
user_id: Optional user ID for per-user path isolation.
Returns:
Dictionary with workspace_path, uploads_path, and outputs_path.
"""
return {
"workspace_path": str(self._paths.sandbox_work_dir(thread_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id)),
"workspace_path": str(self._paths.sandbox_work_dir(thread_id, user_id=user_id)),
"uploads_path": str(self._paths.sandbox_uploads_dir(thread_id, user_id=user_id)),
"outputs_path": str(self._paths.sandbox_outputs_dir(thread_id, user_id=user_id)),
}
def _create_thread_directories(self, thread_id: str) -> dict[str, str]:
def _create_thread_directories(self, thread_id: str, user_id: str | None = None) -> dict[str, str]:
"""Create the thread data directories.
Args:
thread_id: The thread ID.
user_id: Optional user ID for per-user path isolation.
Returns:
Dictionary with the created directory paths.
"""
self._paths.ensure_thread_dirs(thread_id)
return self._get_thread_paths(thread_id)
self._paths.ensure_thread_dirs(thread_id, user_id=user_id)
return self._get_thread_paths(thread_id, user_id=user_id)
@override
def before_agent(self, state: ThreadDataMiddlewareState, runtime: Runtime) -> dict | None:
@ -84,16 +89,30 @@ class ThreadDataMiddleware(AgentMiddleware[ThreadDataMiddlewareState]):
if thread_id is None:
raise ValueError("Thread ID is required in runtime context or config.configurable")
user_id = get_effective_user_id()
if self._lazy_init:
# Lazy initialization: only compute paths, don't create directories
paths = self._get_thread_paths(thread_id)
paths = self._get_thread_paths(thread_id, user_id=user_id)
else:
# Eager initialization: create directories immediately
paths = self._create_thread_directories(thread_id)
paths = self._create_thread_directories(thread_id, user_id=user_id)
logger.debug("Created thread data directories for thread %s", thread_id)
messages = list(state.get("messages", []))
last_message = messages[-1] if messages else None
if last_message and isinstance(last_message, HumanMessage):
messages[-1] = HumanMessage(
content=last_message.content,
id=last_message.id,
name=last_message.name or "user-input",
additional_kwargs={**last_message.additional_kwargs, "run_id": runtime.context.get("run_id"), "timestamp": datetime.now(UTC).isoformat()},
)
return {
"thread_data": {
**paths,
}
},
"messages": messages,
}

View File

@ -2,10 +2,11 @@
import logging
import re
from typing import NotRequired, override
from typing import Any, NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
@ -106,6 +107,21 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _get_runnable_config(self) -> dict[str, Any]:
"""Inherit the parent RunnableConfig and add middleware tag.
This ensures RunJournal identifies LLM calls from this middleware
as ``middleware:title`` instead of ``lead_agent``.
"""
try:
parent = get_config()
except Exception:
parent = {}
config = {**parent}
config["run_name"] = "title_agent"
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Generate a local fallback title without blocking on an LLM call."""
if not self._should_generate_title(state):
@ -127,7 +143,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
model = create_chat_model(thinking_enabled=False)
response = await model.ainvoke(prompt)
response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if title:
return {"title": title}

View File

@ -11,6 +11,8 @@ from langgraph.errors import GraphBubbleUp
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
from deerflow.config.app_config import AppConfig
logger = logging.getLogger(__name__)
_MISSING_TOOL_CALL_ID = "missing_tool_call_id"
@ -67,6 +69,7 @@ class ToolErrorHandlingMiddleware(AgentMiddleware[AgentState]):
def _build_runtime_middlewares(
*,
app_config: AppConfig,
include_uploads: bool,
include_dangling_tool_call_patch: bool,
lazy_init: bool = True,
@ -91,12 +94,10 @@ def _build_runtime_middlewares(
middlewares.append(DanglingToolCallMiddleware())
middlewares.append(LLMErrorHandlingMiddleware())
middlewares.append(LLMErrorHandlingMiddleware(app_config=app_config))
# Guardrail middleware (if configured)
from deerflow.config.guardrails_config import get_guardrails_config
guardrails_config = get_guardrails_config()
guardrails_config = app_config.guardrails
if guardrails_config.enabled and guardrails_config.provider:
import inspect
@ -125,18 +126,20 @@ def _build_runtime_middlewares(
return middlewares
def build_lead_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
def build_lead_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by lead agent runtime before lead-only middlewares."""
return _build_runtime_middlewares(
app_config=app_config,
include_uploads=True,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,
)
def build_subagent_runtime_middlewares(*, lazy_init: bool = True) -> list[AgentMiddleware]:
def build_subagent_runtime_middlewares(*, app_config: AppConfig, lazy_init: bool = True) -> list[AgentMiddleware]:
"""Middlewares shared by subagent runtime before subagent-only middlewares."""
return _build_runtime_middlewares(
app_config=app_config,
include_uploads=False,
include_dangling_tool_call_patch=True,
lazy_init=lazy_init,

View File

@ -10,6 +10,7 @@ from langchain_core.messages import HumanMessage
from langgraph.runtime import Runtime
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.utils.file_conversion import extract_outline
logger = logging.getLogger(__name__)
@ -221,7 +222,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
thread_id = get_config().get("configurable", {}).get("thread_id")
except RuntimeError:
pass # get_config() raises outside a runnable context (e.g. unit tests)
uploads_dir = self._paths.sandbox_uploads_dir(thread_id) if thread_id else None
uploads_dir = self._paths.sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) if thread_id else None
# Get newly uploaded files from the current message's additional_kwargs.files
new_files = self._files_from_kwargs(last_message, uploads_dir) or []
@ -282,6 +283,7 @@ class UploadsMiddleware(AgentMiddleware[UploadsMiddlewareState]):
updated_message = HumanMessage(
content=updated_content,
id=last_message.id,
name=last_message.name,
additional_kwargs=last_message.additional_kwargs,
)

View File

@ -40,7 +40,8 @@ from deerflow.config.app_config import get_app_config, reload_app_config
from deerflow.config.extensions_config import ExtensionsConfig, SkillStateConfig, get_extensions_config, reload_extensions_config
from deerflow.config.paths import get_paths
from deerflow.models import create_chat_model
from deerflow.skills.installer import install_skill_from_archive
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.skills.storage import get_or_new_skill_storage
from deerflow.uploads.manager import (
claim_unique_filename,
delete_file_safe,
@ -240,7 +241,7 @@ class DeerFlowClient:
}
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer import get_checkpointer
from deerflow.runtime.checkpointer import get_checkpointer
checkpointer = get_checkpointer()
if checkpointer is not None:
@ -374,7 +375,7 @@ class DeerFlowClient:
"""
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer.provider import get_checkpointer
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
@ -429,7 +430,7 @@ class DeerFlowClient:
"""
checkpointer = self._checkpointer
if checkpointer is None:
from deerflow.agents.checkpointer.provider import get_checkpointer
from deerflow.runtime.checkpointer.provider import get_checkpointer
checkpointer = get_checkpointer()
@ -751,8 +752,6 @@ class DeerFlowClient:
Dict with "skills" key containing list of skill info dicts,
matching the Gateway API ``SkillsListResponse`` schema.
"""
from deerflow.skills.loader import load_skills
return {
"skills": [
{
@ -762,7 +761,7 @@ class DeerFlowClient:
"category": s.category,
"enabled": s.enabled,
}
for s in load_skills(enabled_only=enabled_only)
for s in get_or_new_skill_storage().load_skills(enabled_only=enabled_only)
]
}
@ -774,19 +773,19 @@ class DeerFlowClient:
"""
from deerflow.agents.memory.updater import get_memory_data
return get_memory_data()
return get_memory_data(user_id=get_effective_user_id())
def export_memory(self) -> dict:
"""Export current memory data for backup or transfer."""
from deerflow.agents.memory.updater import get_memory_data
return get_memory_data()
return get_memory_data(user_id=get_effective_user_id())
def import_memory(self, memory_data: dict) -> dict:
"""Import and persist full memory data."""
from deerflow.agents.memory.updater import import_memory_data
return import_memory_data(memory_data)
return import_memory_data(memory_data, user_id=get_effective_user_id())
def get_model(self, name: str) -> dict | None:
"""Get a specific model's configuration by name.
@ -871,9 +870,9 @@ class DeerFlowClient:
Returns:
Skill info dict, or None if not found.
"""
from deerflow.skills.loader import load_skills
from deerflow.skills.storage import get_or_new_skill_storage
skill = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
skill = next((s for s in get_or_new_skill_storage().load_skills(enabled_only=False) if s.name == name), None)
if skill is None:
return None
return {
@ -898,9 +897,9 @@ class DeerFlowClient:
ValueError: If the skill is not found.
OSError: If the config file cannot be written.
"""
from deerflow.skills.loader import load_skills
from deerflow.skills.storage import get_or_new_skill_storage
skills = load_skills(enabled_only=False)
skills = get_or_new_skill_storage().load_skills(enabled_only=False)
skill = next((s for s in skills if s.name == name), None)
if skill is None:
raise ValueError(f"Skill '{name}' not found")
@ -923,7 +922,7 @@ class DeerFlowClient:
self._agent_config_key = None
reload_extensions_config()
updated = next((s for s in load_skills(enabled_only=False) if s.name == name), None)
updated = next((s for s in get_or_new_skill_storage().load_skills(enabled_only=False) if s.name == name), None)
if updated is None:
raise RuntimeError(f"Skill '{name}' disappeared after update")
return {
@ -947,7 +946,7 @@ class DeerFlowClient:
FileNotFoundError: If the file does not exist.
ValueError: If the file is invalid.
"""
return install_skill_from_archive(skill_path)
return get_or_new_skill_storage().install_skill_from_archive(skill_path)
# ------------------------------------------------------------------
# Public API — memory management
@ -961,13 +960,13 @@ class DeerFlowClient:
"""
from deerflow.agents.memory.updater import reload_memory_data
return reload_memory_data()
return reload_memory_data(user_id=get_effective_user_id())
def clear_memory(self) -> dict:
"""Clear all persisted memory data."""
from deerflow.agents.memory.updater import clear_memory_data
return clear_memory_data()
return clear_memory_data(user_id=get_effective_user_id())
def create_memory_fact(self, content: str, category: str = "context", confidence: float = 0.5) -> dict:
"""Create a single fact manually."""
@ -1184,7 +1183,7 @@ class DeerFlowClient:
ValueError: If the path is invalid.
"""
try:
actual = get_paths().resolve_virtual_path(thread_id, path)
actual = get_paths().resolve_virtual_path(thread_id, path, user_id=get_effective_user_id())
except ValueError as exc:
if "traversal" in str(exc):
from deerflow.uploads.manager import PathTraversalError

View File

@ -27,6 +27,7 @@ except ImportError: # pragma: no cover - Windows fallback
from deerflow.config import get_app_config
from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.sandbox.sandbox import Sandbox
from deerflow.sandbox.sandbox_provider import SandboxProvider
@ -270,15 +271,16 @@ class AioSandboxProvider(SandboxProvider):
mounted Docker socket (DooD), the host Docker daemon can resolve the paths.
"""
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
user_id = get_effective_user_id()
paths.ensure_thread_dirs(thread_id, user_id=user_id)
return [
(paths.host_sandbox_work_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
(paths.host_sandbox_uploads_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
(paths.host_sandbox_outputs_dir(thread_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
(paths.host_sandbox_work_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/workspace", False),
(paths.host_sandbox_uploads_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/uploads", False),
(paths.host_sandbox_outputs_dir(thread_id, user_id=user_id), f"{VIRTUAL_PATH_PREFIX}/outputs", False),
# ACP workspace: read-only inside the sandbox (lead agent reads results;
# the ACP subprocess writes from the host side, not from within the container).
(paths.host_acp_workspace_dir(thread_id), "/mnt/acp-workspace", True),
(paths.host_acp_workspace_dir(thread_id, user_id=user_id), "/mnt/acp-workspace", True),
]
@staticmethod
@ -490,8 +492,9 @@ class AioSandboxProvider(SandboxProvider):
across multiple processes, preventing container-name conflicts.
"""
paths = get_paths()
paths.ensure_thread_dirs(thread_id)
lock_path = paths.thread_dir(thread_id) / f"{sandbox_id}.lock"
user_id = get_effective_user_id()
paths.ensure_thread_dirs(thread_id, user_id=user_id)
lock_path = paths.thread_dir(thread_id, user_id=user_id) / f"{sandbox_id}.lock"
with open(lock_path, "a", encoding="utf-8") as lock_file:
locked = False

View File

@ -9,6 +9,7 @@ from __future__ import annotations
import json
import logging
import os
import shlex
import subprocess
from datetime import datetime
@ -86,6 +87,88 @@ def _format_container_mount(runtime: str, host_path: str, container_path: str, r
return ["-v", mount_spec]
def _redact_container_command_for_log(cmd: list[str]) -> list[str]:
"""Return a Docker/Container command with environment values redacted."""
redacted: list[str] = []
redact_next_env = False
for arg in cmd:
if redact_next_env:
if "=" in arg:
key = arg.split("=", 1)[0]
redacted.append(f"{key}=<redacted>" if key else "<redacted>")
else:
redacted.append(arg)
redact_next_env = False
continue
if arg in {"-e", "--env"}:
redacted.append(arg)
redact_next_env = True
continue
if arg.startswith("--env="):
value = arg.removeprefix("--env=")
if "=" in value:
key = value.split("=", 1)[0]
redacted.append(f"--env={key}=<redacted>" if key else "--env=<redacted>")
else:
redacted.append(arg)
continue
redacted.append(arg)
return redacted
def _format_container_command_for_log(cmd: list[str]) -> str:
if os.name == "nt":
return subprocess.list2cmdline(cmd)
return shlex.join(cmd)
def _normalize_sandbox_host(host: str) -> str:
return host.strip().lower()
def _is_ipv6_loopback_sandbox_host(host: str) -> bool:
return _normalize_sandbox_host(host) in {"::1", "[::1]"}
def _is_loopback_sandbox_host(host: str) -> bool:
return _normalize_sandbox_host(host) in {"", "localhost", "127.0.0.1", "::1", "[::1]"}
def _resolve_docker_bind_host(sandbox_host: str | None = None, bind_host: str | None = None) -> str:
"""Choose the host interface for legacy Docker ``-p`` sandbox publishing.
Bare-metal/local runs talk to sandboxes through localhost and should not
expose the sandbox HTTP API on every host interface. Docker-outside-of-
Docker deployments commonly use ``host.docker.internal`` from another
container; keep their legacy broad bind unless operators opt into a
narrower bind with ``DEER_FLOW_SANDBOX_BIND_HOST``. When operators choose
an IPv6 loopback sandbox host, bind Docker to IPv6 loopback as well so the
advertised sandbox URL and published socket use the same address family.
"""
explicit_bind = bind_host if bind_host is not None else os.environ.get("DEER_FLOW_SANDBOX_BIND_HOST")
if explicit_bind is not None:
explicit_bind = explicit_bind.strip()
if explicit_bind:
logger.debug("Docker sandbox bind: %s (explicit bind host override)", explicit_bind)
return explicit_bind
host = sandbox_host if sandbox_host is not None else os.environ.get("DEER_FLOW_SANDBOX_HOST", "localhost")
if _is_ipv6_loopback_sandbox_host(host):
logger.debug("Docker sandbox bind: [::1] (IPv6 loopback sandbox host)")
return "[::1]"
if _is_loopback_sandbox_host(host):
logger.debug("Docker sandbox bind: 127.0.0.1 (loopback default)")
return "127.0.0.1"
logger.debug("Docker sandbox bind: 0.0.0.0 (non-loopback sandbox host compatibility)")
return "0.0.0.0"
class LocalContainerBackend(SandboxBackend):
"""Backend that manages sandbox containers locally using Docker or Apple Container.
@ -424,12 +507,17 @@ class LocalContainerBackend(SandboxBackend):
if self._runtime == "docker":
cmd.extend(["--security-opt", "seccomp=unconfined"])
if self._runtime == "docker":
port_mapping = f"{_resolve_docker_bind_host()}:{port}:8080"
else:
port_mapping = f"{port}:8080"
cmd.extend(
[
"--rm",
"-d",
"-p",
f"{port}:8080",
port_mapping,
"--name",
container_name,
]
@ -464,7 +552,8 @@ class LocalContainerBackend(SandboxBackend):
cmd.append(self._image)
logger.info(f"Starting container using {self._runtime}: {' '.join(cmd)}")
log_cmd = _format_container_command_for_log(_redact_container_command_for_log(cmd))
logger.info(f"Starting container using {self._runtime}: {log_cmd}")
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=True)

View File

@ -38,6 +38,6 @@ class JinaClient:
return response.text
except Exception as e:
error_message = f"Request to Jina API failed: {str(e)}"
logger.exception(error_message)
error_message = f"Request to Jina API failed: {type(e).__name__}: {e}"
logger.warning(error_message)
return f"Error: {error_message}"

View File

@ -11,10 +11,12 @@ from pydantic import BaseModel, ConfigDict, Field
from deerflow.config.acp_config import load_acp_config_from_dict
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
from deerflow.config.database_config import DatabaseConfig
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.config.guardrails_config import GuardrailsConfig, load_guardrails_config_from_dict
from deerflow.config.memory_config import MemoryConfig, load_memory_config_from_dict
from deerflow.config.model_config import ModelConfig
from deerflow.config.run_events_config import RunEventsConfig
from deerflow.config.sandbox_config import SandboxConfig
from deerflow.config.skill_evolution_config import SkillEvolutionConfig
from deerflow.config.skills_config import SkillsConfig
@ -31,6 +33,12 @@ load_dotenv()
logger = logging.getLogger(__name__)
CONFIG_FILE_DATABASE_DEFAULTS = {
"backend": "sqlite",
"sqlite_dir": ".deer-flow/data",
}
class CircuitBreakerConfig(BaseModel):
"""Configuration for the LLM Circuit Breaker."""
@ -45,10 +53,34 @@ def _default_config_candidates() -> tuple[Path, ...]:
return (backend_dir / "config.yaml", repo_root / "config.yaml")
def logging_level_from_config(name: str | None) -> int:
"""Map ``config.yaml`` ``log_level`` string to a :mod:`logging` level constant."""
mapping = logging.getLevelNamesMapping()
return mapping.get((name or "info").strip().upper(), logging.INFO)
def apply_logging_level(name: str | None) -> None:
"""Resolve *name* to a logging level and apply it to the ``deerflow``/``app`` logger hierarchies.
Only the ``deerflow`` and ``app`` logger levels are changed so that
third-party library verbosity (e.g. uvicorn, sqlalchemy) is not
affected. Root handler levels are lowered (never raised) so that
messages from the configured loggers can propagate through without
being filtered, while preserving handler thresholds that may be
intentionally restrictive for third-party log output.
"""
level = logging_level_from_config(name)
for logger_name in ("deerflow", "app"):
logging.getLogger(logger_name).setLevel(level)
for handler in logging.root.handlers:
if level < handler.level:
handler.setLevel(level)
class AppConfig(BaseModel):
"""Config for the DeerFlow application"""
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
log_level: str = Field(default="info", description="Logging level for deerflow and app modules (debug/info/warning/error); third-party libraries are not affected")
token_usage: TokenUsageConfig = Field(default_factory=TokenUsageConfig, description="Token usage tracking configuration")
models: list[ModelConfig] = Field(default_factory=list, description="Available models")
sandbox: SandboxConfig = Field(description="Sandbox configuration")
@ -65,7 +97,9 @@ class AppConfig(BaseModel):
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
model_config = ConfigDict(extra="allow", frozen=False)
model_config = ConfigDict(extra="allow")
database: DatabaseConfig = Field(default_factory=DatabaseConfig, description="Unified database backend configuration")
run_events: RunEventsConfig = Field(default_factory=RunEventsConfig, description="Run event storage configuration")
checkpointer: CheckpointerConfig | None = Field(default=None, description="Checkpointer configuration")
stream_bridge: StreamBridgeConfig | None = Field(default=None, description="Stream bridge configuration")
@ -114,6 +148,7 @@ class AppConfig(BaseModel):
cls._check_config_version(config_data, resolved_path)
config_data = cls.resolve_env_variables(config_data)
cls._apply_database_defaults(config_data)
# Load title config if present
if "title" in config_data:
@ -165,6 +200,18 @@ class AppConfig(BaseModel):
result = cls.model_validate(config_data)
return result
@classmethod
def _apply_database_defaults(cls, config_data: dict[str, Any]) -> None:
"""Apply config.yaml defaults for persistence when the section is absent."""
database_config = config_data.get("database")
if database_config is None:
database_config = {}
config_data["database"] = database_config
if not isinstance(database_config, dict):
return
for key, value in CONFIG_FILE_DATABASE_DEFAULTS.items():
database_config.setdefault(key, value)
@classmethod
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
@ -269,6 +316,9 @@ class AppConfig(BaseModel):
return next((group for group in self.tool_groups if group.name == name), None)
# Compatibility singleton layer for code paths that have not yet been
# migrated to explicit ``AppConfig`` threading. New composition roots should
# prefer constructing ``AppConfig`` once and passing it down directly.
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None

View File

@ -0,0 +1,102 @@
"""Unified database backend configuration.
Controls BOTH the LangGraph checkpointer and the DeerFlow application
persistence layer (runs, threads metadata, users, etc.). The user
configures one backend; the system handles physical separation details.
SQLite mode: checkpointer and app share a single .db file
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
connection. WAL allows concurrent readers and a single writer without
blocking, making a unified file safe for both workloads. Writers
that contend for the lock wait via the default 5-second sqlite3
busy timeout rather than failing immediately.
Postgres mode: both use the same database URL but maintain independent
connection pools with different lifecycles.
Memory mode: checkpointer uses MemorySaver, app uses in-memory stores.
No database is initialized.
Sensitive values (postgres_url) should use $VAR syntax in config.yaml
to reference environment variables from .env:
database:
backend: postgres
postgres_url: $DATABASE_URL
The $VAR resolution is handled by AppConfig.resolve_env_variables()
before this config is instantiated -- DatabaseConfig itself does not
need to do any environment variable processing.
"""
from __future__ import annotations
import os
from typing import Literal
from pydantic import BaseModel, Field
class DatabaseConfig(BaseModel):
backend: Literal["memory", "sqlite", "postgres"] = Field(
default="memory",
description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."),
)
sqlite_dir: str = Field(
default=".deer-flow/data",
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
)
postgres_url: str = Field(
default="",
description=(
"PostgreSQL connection URL, shared by checkpointer and app. "
"Use $DATABASE_URL in config.yaml to reference .env. "
"Example: postgresql://user:pass@host:5432/deerflow "
"(the +asyncpg driver suffix is added automatically where needed)."
),
)
echo_sql: bool = Field(
default=False,
description="Echo all SQL statements to log (debug only).",
)
pool_size: int = Field(
default=5,
description="Connection pool size for the app ORM engine (postgres only).",
)
# -- Derived helpers (not user-configured) --
@property
def _resolved_sqlite_dir(self) -> str:
"""Resolve sqlite_dir to an absolute path (relative to CWD)."""
from pathlib import Path
return str(Path(self.sqlite_dir).resolve())
@property
def sqlite_path(self) -> str:
"""Unified SQLite file path shared by checkpointer and app."""
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
# Backward-compatible aliases
@property
def checkpointer_sqlite_path(self) -> str:
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
return self.sqlite_path
@property
def app_sqlite_path(self) -> str:
"""SQLite file path for application ORM data (alias for sqlite_path)."""
return self.sqlite_path
@property
def app_sqlalchemy_url(self) -> str:
"""SQLAlchemy async URL for the application ORM engine."""
if self.backend == "sqlite":
return f"sqlite+aiosqlite:///{self.sqlite_path}"
if self.backend == "postgres":
url = self.postgres_url
if url.startswith("postgresql://"):
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
return url
raise ValueError(f"No SQLAlchemy URL for backend={self.backend!r}")

View File

@ -14,8 +14,9 @@ class MemoryConfig(BaseModel):
default="",
description=(
"Path to store memory data. "
"If empty, defaults to `{base_dir}/memory.json` (see Paths.memory_file). "
"Absolute paths are used as-is. "
"If empty, defaults to per-user memory at `{base_dir}/users/{user_id}/memory.json`. "
"Absolute paths are used as-is and opt out of per-user isolation "
"(all users share the same file). "
"Relative paths are resolved against `Paths.base_dir` "
"(not the backend working directory). "
"Note: if you previously set this to `.deer-flow/memory.json`, "

View File

@ -7,6 +7,7 @@ from pathlib import Path, PureWindowsPath
VIRTUAL_PATH_PREFIX = "/mnt/user-data"
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
def _default_local_base_dir() -> Path:
@ -22,6 +23,13 @@ def _validate_thread_id(thread_id: str) -> str:
return thread_id
def _validate_user_id(user_id: str) -> str:
"""Validate a user ID before using it in filesystem paths."""
if not _SAFE_USER_ID_RE.match(user_id):
raise ValueError(f"Invalid user_id {user_id!r}: only alphanumeric characters, hyphens, and underscores are allowed.")
return user_id
def _join_host_path(base: str, *parts: str) -> str:
"""Join host filesystem path segments while preserving native style.
@ -134,44 +142,63 @@ class Paths:
"""Per-agent memory file: `{base_dir}/agents/{name}/memory.json`."""
return self.agent_dir(name) / "memory.json"
def thread_dir(self, thread_id: str) -> Path:
def user_dir(self, user_id: str) -> Path:
"""Directory for a specific user: `{base_dir}/users/{user_id}/`."""
return self.base_dir / "users" / _validate_user_id(user_id)
def user_memory_file(self, user_id: str) -> Path:
"""Per-user memory file: `{base_dir}/users/{user_id}/memory.json`."""
return self.user_dir(user_id) / "memory.json"
def user_agent_memory_file(self, user_id: str, agent_name: str) -> Path:
"""Per-user per-agent memory: `{base_dir}/users/{user_id}/agents/{name}/memory.json`."""
return self.user_dir(user_id) / "agents" / agent_name.lower() / "memory.json"
def thread_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for a thread's data: `{base_dir}/threads/{thread_id}/`
Host path for a thread's data.
When *user_id* is provided:
`{base_dir}/users/{user_id}/threads/{thread_id}/`
Otherwise (legacy layout):
`{base_dir}/threads/{thread_id}/`
This directory contains a `user-data/` subdirectory that is mounted
as `/mnt/user-data/` inside the sandbox.
Raises:
ValueError: If `thread_id` contains unsafe characters (path separators
or `..`) that could cause directory traversal.
ValueError: If `thread_id` or `user_id` contains unsafe characters (path
separators or `..`) that could cause directory traversal.
"""
if user_id is not None:
return self.user_dir(user_id) / "threads" / _validate_thread_id(thread_id)
return self.base_dir / "threads" / _validate_thread_id(thread_id)
def sandbox_work_dir(self, thread_id: str) -> Path:
def sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for the agent's workspace directory.
Host: `{base_dir}/threads/{thread_id}/user-data/workspace/`
Sandbox: `/mnt/user-data/workspace/`
"""
return self.thread_dir(thread_id) / "user-data" / "workspace"
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "workspace"
def sandbox_uploads_dir(self, thread_id: str) -> Path:
def sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for user-uploaded files.
Host: `{base_dir}/threads/{thread_id}/user-data/uploads/`
Sandbox: `/mnt/user-data/uploads/`
"""
return self.thread_dir(thread_id) / "user-data" / "uploads"
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "uploads"
def sandbox_outputs_dir(self, thread_id: str) -> Path:
def sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for agent-generated artifacts.
Host: `{base_dir}/threads/{thread_id}/user-data/outputs/`
Sandbox: `/mnt/user-data/outputs/`
"""
return self.thread_dir(thread_id) / "user-data" / "outputs"
return self.thread_dir(thread_id, user_id=user_id) / "user-data" / "outputs"
def acp_workspace_dir(self, thread_id: str) -> Path:
def acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for the ACP workspace of a specific thread.
Host: `{base_dir}/threads/{thread_id}/acp-workspace/`
@ -180,41 +207,43 @@ class Paths:
Each thread gets its own isolated ACP workspace so that concurrent
sessions cannot read each other's ACP agent outputs.
"""
return self.thread_dir(thread_id) / "acp-workspace"
return self.thread_dir(thread_id, user_id=user_id) / "acp-workspace"
def sandbox_user_data_dir(self, thread_id: str) -> Path:
def sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> Path:
"""
Host path for the user-data root.
Host: `{base_dir}/threads/{thread_id}/user-data/`
Sandbox: `/mnt/user-data/`
"""
return self.thread_dir(thread_id) / "user-data"
return self.thread_dir(thread_id, user_id=user_id) / "user-data"
def host_thread_dir(self, thread_id: str) -> str:
def host_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for a thread directory, preserving Windows path syntax."""
if user_id is not None:
return _join_host_path(self._host_base_dir_str(), "users", _validate_user_id(user_id), "threads", _validate_thread_id(thread_id))
return _join_host_path(self._host_base_dir_str(), "threads", _validate_thread_id(thread_id))
def host_sandbox_user_data_dir(self, thread_id: str) -> str:
def host_sandbox_user_data_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for a thread's user-data root."""
return _join_host_path(self.host_thread_dir(thread_id), "user-data")
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "user-data")
def host_sandbox_work_dir(self, thread_id: str) -> str:
def host_sandbox_work_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the workspace mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "workspace")
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "workspace")
def host_sandbox_uploads_dir(self, thread_id: str) -> str:
def host_sandbox_uploads_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the uploads mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "uploads")
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "uploads")
def host_sandbox_outputs_dir(self, thread_id: str) -> str:
def host_sandbox_outputs_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the outputs mount source."""
return _join_host_path(self.host_sandbox_user_data_dir(thread_id), "outputs")
return _join_host_path(self.host_sandbox_user_data_dir(thread_id, user_id=user_id), "outputs")
def host_acp_workspace_dir(self, thread_id: str) -> str:
def host_acp_workspace_dir(self, thread_id: str, *, user_id: str | None = None) -> str:
"""Host path for the ACP workspace mount source."""
return _join_host_path(self.host_thread_dir(thread_id), "acp-workspace")
return _join_host_path(self.host_thread_dir(thread_id, user_id=user_id), "acp-workspace")
def ensure_thread_dirs(self, thread_id: str) -> None:
def ensure_thread_dirs(self, thread_id: str, *, user_id: str | None = None) -> None:
"""Create all standard sandbox directories for a thread.
Directories are created with mode 0o777 so that sandbox containers
@ -228,24 +257,24 @@ class Paths:
ACP agent invocation.
"""
for d in [
self.sandbox_work_dir(thread_id),
self.sandbox_uploads_dir(thread_id),
self.sandbox_outputs_dir(thread_id),
self.acp_workspace_dir(thread_id),
self.sandbox_work_dir(thread_id, user_id=user_id),
self.sandbox_uploads_dir(thread_id, user_id=user_id),
self.sandbox_outputs_dir(thread_id, user_id=user_id),
self.acp_workspace_dir(thread_id, user_id=user_id),
]:
d.mkdir(parents=True, exist_ok=True)
d.chmod(0o777)
def delete_thread_dir(self, thread_id: str) -> None:
def delete_thread_dir(self, thread_id: str, *, user_id: str | None = None) -> None:
"""Delete all persisted data for a thread.
The operation is idempotent: missing thread directories are ignored.
"""
thread_dir = self.thread_dir(thread_id)
thread_dir = self.thread_dir(thread_id, user_id=user_id)
if thread_dir.exists():
shutil.rmtree(thread_dir)
def resolve_virtual_path(self, thread_id: str, virtual_path: str) -> Path:
def resolve_virtual_path(self, thread_id: str, virtual_path: str, *, user_id: str | None = None) -> Path:
"""Resolve a sandbox virtual path to the actual host filesystem path.
Args:
@ -253,6 +282,7 @@ class Paths:
virtual_path: Virtual path as seen inside the sandbox, e.g.
``/mnt/user-data/outputs/report.pdf``.
Leading slashes are stripped before matching.
user_id: Optional user ID for user-scoped path resolution.
Returns:
The resolved absolute host filesystem path.
@ -270,7 +300,7 @@ class Paths:
raise ValueError(f"Path must start with /{prefix}")
relative = stripped[len(prefix) :].lstrip("/")
base = self.sandbox_user_data_dir(thread_id).resolve()
base = self.sandbox_user_data_dir(thread_id, user_id=user_id).resolve()
actual = (base / relative).resolve()
try:

View File

@ -0,0 +1,33 @@
"""Run event storage configuration.
Controls where run events (messages + execution traces) are persisted.
Backends:
- memory: In-memory storage, data lost on restart. Suitable for
development and testing.
- db: SQL database via SQLAlchemy ORM. Provides full query capability.
Suitable for production deployments.
- jsonl: Append-only JSONL files. Lightweight alternative for
single-node deployments that need persistence without a database.
"""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
class RunEventsConfig(BaseModel):
backend: Literal["memory", "db", "jsonl"] = Field(
default="memory",
description="Storage backend for run events. 'memory' for development (no persistence), 'db' for production (SQL queries), 'jsonl' for lightweight single-node persistence.",
)
max_trace_content: int = Field(
default=10240,
description="Maximum trace content size in bytes before truncation (db backend only).",
)
track_token_usage: bool = Field(
default=True,
description="Whether RunJournal should accumulate token counts to RunRow.",
)

View File

@ -11,6 +11,10 @@ def _default_repo_root() -> Path:
class SkillsConfig(BaseModel):
"""Configuration for skills system"""
use: str = Field(
default="deerflow.skills.storage.local_skill_storage:LocalSkillStorage",
description="Class path of the SkillStorage implementation.",
)
path: str | None = Field(
default=None,
description="Path to skills directory. If not specified, defaults to ../skills relative to backend directory",
@ -35,10 +39,8 @@ class SkillsConfig(BaseModel):
path = _default_repo_root() / path
return path.resolve()
else:
# Default: ../skills relative to backend directory
from deerflow.skills.loader import get_skills_root_path
return get_skills_root_path()
# Default: <repo_root>/skills
return _default_repo_root() / "skills"
def get_skill_container_path(self, skill_name: str, category: str = "public") -> str:
"""

View File

@ -25,6 +25,47 @@ class SubagentOverrideConfig(BaseModel):
min_length=1,
description="Model name for this subagent (None = inherit from parent agent)",
)
skills: list[str] | None = Field(
default=None,
description="Skill names whitelist for this subagent (None = inherit all enabled skills, [] = no skills)",
)
class CustomSubagentConfig(BaseModel):
"""User-defined subagent type declared in config.yaml."""
description: str = Field(
description="When the lead agent should delegate to this subagent",
)
system_prompt: str = Field(
description="System prompt that guides the subagent's behavior",
)
tools: list[str] | None = Field(
default=None,
description="Tool names whitelist (None = inherit all tools from parent)",
)
disallowed_tools: list[str] | None = Field(
default_factory=lambda: ["task", "ask_clarification", "present_files"],
description="Tool names to deny",
)
skills: list[str] | None = Field(
default=None,
description="Skill names whitelist (None = inherit all enabled skills, [] = no skills)",
)
model: str = Field(
default="inherit",
description="Model to use - 'inherit' uses parent's model",
)
max_turns: int = Field(
default=50,
ge=1,
description="Maximum number of agent turns before stopping",
)
timeout_seconds: int = Field(
default=900,
ge=1,
description="Maximum execution time in seconds",
)
class SubagentsAppConfig(BaseModel):
@ -44,6 +85,10 @@ class SubagentsAppConfig(BaseModel):
default_factory=dict,
description="Per-agent configuration overrides keyed by agent name",
)
custom_agents: dict[str, CustomSubagentConfig] = Field(
default_factory=dict,
description="User-defined subagent types keyed by agent name",
)
def get_timeout_for(self, agent_name: str) -> int:
"""Get the effective timeout for a specific agent.
@ -82,6 +127,20 @@ class SubagentsAppConfig(BaseModel):
return self.max_turns
return builtin_default
def get_skills_for(self, agent_name: str) -> list[str] | None:
"""Get the skills override for a specific agent.
Args:
agent_name: The name of the subagent.
Returns:
Skill names whitelist if overridden, None otherwise (subagent will inherit all enabled skills).
"""
override = self.agents.get(agent_name)
if override is not None and override.skills is not None:
return override.skills
return None
_subagents_config: SubagentsAppConfig = SubagentsAppConfig()
@ -105,15 +164,20 @@ def load_subagents_config_from_dict(config_dict: dict) -> None:
parts.append(f"max_turns={override.max_turns}")
if override.model is not None:
parts.append(f"model={override.model}")
if override.skills is not None:
parts.append(f"skills={override.skills}")
if parts:
overrides_summary[name] = ", ".join(parts)
if overrides_summary:
custom_agents_names = list(_subagents_config.custom_agents.keys())
if overrides_summary or custom_agents_names:
logger.info(
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s",
"Subagents config loaded: default timeout=%ss, default max_turns=%s, per-agent overrides=%s, custom_agents=%s",
_subagents_config.timeout_seconds,
_subagents_config.max_turns,
overrides_summary,
overrides_summary or "none",
custom_agents_names or "none",
)
else:
logger.info(

View File

@ -51,6 +51,25 @@ class SummarizationConfig(BaseModel):
default=None,
description="Custom prompt template for generating summaries. If not provided, uses the default LangChain prompt.",
)
preserve_recent_skill_count: int = Field(
default=5,
ge=0,
description="Number of most-recently-loaded skill files to exclude from summarization. Set to 0 to disable skill preservation.",
)
preserve_recent_skill_tokens: int = Field(
default=25000,
ge=0,
description="Total token budget reserved for recently-loaded skill files that must be preserved across summarization.",
)
preserve_recent_skill_tokens_per_skill: int = Field(
default=5000,
ge=0,
description="Per-skill token cap when preserving skill files across summarization. Skill reads above this size are not rescued.",
)
skill_file_read_tool_names: list[str] = Field(
default_factory=lambda: ["read_file", "read", "view", "cat"],
description="Tool names treated as skill file reads when preserving recently-loaded skills across summarization.",
)
# Global configuration instance

View File

@ -12,6 +12,7 @@ from langchain_core.tools import BaseTool
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.reflection import resolve_variable
logger = logging.getLogger(__name__)
@ -95,6 +96,27 @@ async def get_mcp_tools() -> list[BaseTool]:
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
# Load custom interceptors declared in extensions_config.json
# Format: "mcpInterceptors": ["pkg.module:builder_func", ...]
raw_interceptor_paths = (extensions_config.model_extra or {}).get("mcpInterceptors")
if isinstance(raw_interceptor_paths, str):
raw_interceptor_paths = [raw_interceptor_paths]
elif not isinstance(raw_interceptor_paths, list):
if raw_interceptor_paths is not None:
logger.warning(f"mcpInterceptors must be a list of strings, got {type(raw_interceptor_paths).__name__}; skipping")
raw_interceptor_paths = []
for interceptor_path in raw_interceptor_paths:
try:
builder = resolve_variable(interceptor_path)
interceptor = builder()
if callable(interceptor):
tool_interceptors.append(interceptor)
logger.info(f"Loaded MCP interceptor: {interceptor_path}")
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
# Get all tools from all servers

View File

@ -190,23 +190,33 @@ class ClaudeChatModel(ChatAnthropic):
)
def _apply_prompt_caching(self, payload: dict) -> None:
"""Apply ephemeral cache_control to system and recent messages."""
# Cache system messages
"""Apply ephemeral cache_control to system, recent messages, and last tool definition.
Uses a budget of MAX_CACHE_BREAKPOINTS (4) breakpoints the hard limit
enforced by both the Anthropic API and AWS Bedrock. Breakpoints are
placed on the *last* eligible blocks because later breakpoints cover a
larger prefix and yield better cache hit rates.
"""
MAX_CACHE_BREAKPOINTS = 4
# Collect candidate blocks in document order:
# 1. system text blocks
# 2. content blocks of the last prompt_cache_size messages
# 3. the last tool definition
candidates: list[dict] = []
# 1. System blocks
system = payload.get("system")
if system and isinstance(system, list):
for block in system:
if isinstance(block, dict) and block.get("type") == "text":
block["cache_control"] = {"type": "ephemeral"}
candidates.append(block)
elif system and isinstance(system, str):
payload["system"] = [
{
"type": "text",
"text": system,
"cache_control": {"type": "ephemeral"},
}
]
new_block: dict = {"type": "text", "text": system}
payload["system"] = [new_block]
candidates.append(new_block)
# Cache recent messages
# 2. Recent message blocks
messages = payload.get("messages", [])
cache_start = max(0, len(messages) - self.prompt_cache_size)
for i in range(cache_start, len(messages)):
@ -217,20 +227,21 @@ class ClaudeChatModel(ChatAnthropic):
if isinstance(content, list):
for block in content:
if isinstance(block, dict):
block["cache_control"] = {"type": "ephemeral"}
candidates.append(block)
elif isinstance(content, str) and content:
msg["content"] = [
{
"type": "text",
"text": content,
"cache_control": {"type": "ephemeral"},
}
]
new_block = {"type": "text", "text": content}
msg["content"] = [new_block]
candidates.append(new_block)
# Cache the last tool definition
# 3. Last tool definition
tools = payload.get("tools", [])
if tools and isinstance(tools[-1], dict):
tools[-1]["cache_control"] = {"type": "ephemeral"}
candidates.append(tools[-1])
# Apply cache_control only to the last MAX_CACHE_BREAKPOINTS candidates
# to stay within the API limit.
for block in candidates[-MAX_CACHE_BREAKPOINTS:]:
block["cache_control"] = {"type": "ephemeral"}
def _apply_thinking_budget(self, payload: dict) -> None:
"""Auto-allocate thinking budget (80% of max_tokens)."""

View File

@ -3,6 +3,7 @@ import logging
from langchain.chat_models import BaseChatModel
from deerflow.config import get_app_config
from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_class
from deerflow.tracing import build_tracing_callbacks
@ -46,7 +47,7 @@ def _enable_stream_usage_by_default(model_use_path: str, model_settings_from_con
model_settings_from_config["stream_usage"] = True
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, **kwargs) -> BaseChatModel:
def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *, app_config: AppConfig | None = None, **kwargs) -> BaseChatModel:
"""Create a chat model instance from the config.
Args:
@ -55,7 +56,7 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
Returns:
A chat model instance.
"""
config = get_app_config()
config = app_config or get_app_config()
if name is None:
name = config.models[0].name
model_config = config.get_model_config(name)
@ -131,7 +132,22 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium"
model_instance = model_class(**{**model_settings_from_config, **kwargs})
# For MindIE models: enforce conservative retry defaults.
# Timeout normalization is handled inside MindIEChatModel itself.
if getattr(model_class, "__name__", "") == "MindIEChatModel":
# Enforce max_retries constraint to prevent cascading timeouts.
model_settings_from_config["max_retries"] = model_settings_from_config.get("max_retries", 1)
# Ensure stream_usage is enabled so that token usage metadata is available
# in streaming responses. LangChain's BaseChatOpenAI only defaults
# stream_usage=True when no custom base_url/api_base is set, so models
# hitting third-party endpoints (e.g. doubao, deepseek) silently lose
# usage data. We default it to True unless explicitly configured.
if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
if "stream_usage" in getattr(model_class, "model_fields", {}):
model_settings_from_config["stream_usage"] = True
model_instance = model_class(**kwargs, **model_settings_from_config)
callbacks = build_tracing_callbacks()
if callbacks:

View File

@ -0,0 +1,249 @@
import ast
import html
import json
import re
import uuid
from collections.abc import Iterator
import httpx
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage
from langchain_core.outputs import ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
def _fix_messages(messages: list) -> list:
"""Sanitize incoming messages for MindIE compatibility.
MindIE's chat template may fail to parse LangChain's native tool_calls
or ToolMessage roles, resulting in 0-token generation errors. This function
flattens multi-modal list contents into strings and converts tool-related
messages into raw text with XML tags expected by the underlying model.
"""
fixed = []
for msg in messages:
# Flatten content if it's a list of blocks
if isinstance(msg.content, list):
parts = []
for block in msg.content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") == "text":
parts.append(block.get("text", ""))
text = "".join(parts)
else:
text = msg.content or ""
# Convert AIMessage with tool_calls to raw XML text format
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", []):
xml_parts = []
for tool in msg.tool_calls:
args_xml = " ".join(f"<parameter={html.escape(str(k), quote=False)}>{html.escape(v if isinstance(v, str) else json.dumps(v, ensure_ascii=False), quote=False)}</parameter>" for k, v in tool.get("args", {}).items())
xml_parts.append(f"<tool_call> <function={html.escape(str(tool['name']), quote=False)}> {args_xml} </function> </tool_call>")
full_text = f"{text}\n" + "\n".join(xml_parts) if text else "\n".join(xml_parts)
fixed.append(AIMessage(content=full_text.strip() or " "))
continue
# Wrap tool execution results in XML tags and convert to HumanMessage
if isinstance(msg, ToolMessage):
tool_result_text = f"<tool_response>\n{text}\n</tool_response>"
fixed.append(HumanMessage(content=tool_result_text))
continue
# Fallback to prevent completely empty message content
if not text.strip():
text = " "
fixed.append(msg.model_copy(update={"content": text}))
return fixed
def _parse_xml_tool_call_to_dict(content: str) -> tuple[str, list[dict]]:
"""Parse XML-style tool calls from model output into LangChain dicts.
Args:
content: The raw text output from the model.
Returns:
A tuple containing the cleaned text (with XML blocks removed) and
a list of tool call dictionaries formatted for LangChain.
"""
if not isinstance(content, str) or "<tool_call>" not in content:
return content, []
tool_calls = []
clean_parts: list[str] = []
cursor = 0
for start, end, inner_content in _iter_tool_call_blocks(content):
clean_parts.append(content[cursor:start])
cursor = end
func_match = re.search(r"<function=([^>]+)>", inner_content)
if not func_match:
continue
function_name = html.unescape(func_match.group(1).strip())
# Ignore nested tool blocks when extracting parameters for this call.
# Nested `<tool_call>` sections represent separate invocations and
# their `<parameter>` tags must not leak into the current call args.
param_source_parts: list[str] = []
nested_cursor = 0
for nested_start, nested_end, _ in _iter_tool_call_blocks(inner_content):
param_source_parts.append(inner_content[nested_cursor:nested_start])
nested_cursor = nested_end
param_source_parts.append(inner_content[nested_cursor:])
param_source = "".join(param_source_parts)
args = {}
param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
for param_match in param_pattern.finditer(param_source):
key = html.unescape(param_match.group(1).strip())
raw_value = html.unescape(param_match.group(2).strip())
# Attempt to deserialize string values into native Python types
# to satisfy downstream Pydantic validation.
parsed_value = raw_value
if raw_value.startswith(("[", "{")) or raw_value in ("true", "false", "null") or raw_value.isdigit():
try:
parsed_value = json.loads(raw_value)
except json.JSONDecodeError:
try:
parsed_value = ast.literal_eval(raw_value)
except (ValueError, SyntaxError):
pass
args[key] = parsed_value
tool_calls.append({"name": function_name, "args": args, "id": f"call_{uuid.uuid4().hex[:10]}"})
clean_parts.append(content[cursor:])
return "".join(clean_parts).strip(), tool_calls
def _iter_tool_call_blocks(content: str) -> Iterator[tuple[int, int, str]]:
"""Iterate `<tool_call>...</tool_call>` blocks and tolerate nesting."""
token_pattern = re.compile(r"</?tool_call>")
depth = 0
block_start = -1
for match in token_pattern.finditer(content):
token = match.group(0)
if token == "<tool_call>":
if depth == 0:
block_start = match.start()
depth += 1
continue
if depth == 0:
continue
depth -= 1
if depth == 0 and block_start != -1:
block_end = match.end()
inner_start = block_start + len("<tool_call>")
inner_end = match.start()
yield block_start, block_end, content[inner_start:inner_end]
block_start = -1
def _decode_escaped_newlines_outside_fences(content: str) -> str:
"""Decode literal `\\n` outside fenced code blocks."""
if "\\n" not in content:
return content
parts = re.split(r"(```[\s\S]*?```)", content)
for idx, part in enumerate(parts):
if part.startswith("```"):
continue
parts[idx] = part.replace("\\n", "\n")
return "".join(parts)
class MindIEChatModel(ChatOpenAI):
"""Chat model adapter for MindIE engine.
Addresses compatibility issues including:
- Flattening multimodal list contents to strings.
- Intercepting and parsing hardcoded XML tool calls into LangChain standard.
- Handling stream=True dropping choices when tools are present by falling back
to non-streaming generation and yielding simulated chunks.
- Fixing over-escaped newline characters from gateway responses.
"""
def __init__(self, **kwargs):
"""Normalize timeout kwargs without creating long-lived clients."""
connect_timeout = kwargs.pop("connect_timeout", 30.0)
read_timeout = kwargs.pop("read_timeout", 900.0)
write_timeout = kwargs.pop("write_timeout", 60.0)
pool_timeout = kwargs.pop("pool_timeout", 30.0)
kwargs.setdefault(
"timeout",
httpx.Timeout(
connect=connect_timeout,
read=read_timeout,
write=write_timeout,
pool=pool_timeout,
),
)
super().__init__(**kwargs)
def _patch_result_with_tools(self, result: ChatResult) -> ChatResult:
"""Apply post-generation fixes to the model result."""
for gen in result.generations:
msg = gen.message
if isinstance(msg.content, str):
# Keep escaped newlines inside fenced code blocks untouched.
msg.content = _decode_escaped_newlines_outside_fences(msg.content)
if "<tool_call>" in msg.content:
clean_content, extracted_tools = _parse_xml_tool_call_to_dict(msg.content)
if extracted_tools:
msg.content = clean_content
if getattr(msg, "tool_calls", None) is None:
msg.tool_calls = []
msg.tool_calls.extend(extracted_tools)
return result
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
result = super()._generate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs)
return self._patch_result_with_tools(result)
async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs):
result = await super()._agenerate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs)
return self._patch_result_with_tools(result)
async def _astream(self, messages, stop=None, run_manager=None, **kwargs):
# Route standard queries to native streaming for lower TTFB
if not kwargs.get("tools"):
async for chunk in super()._astream(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs):
if isinstance(chunk.message.content, str):
chunk.message.content = _decode_escaped_newlines_outside_fences(chunk.message.content)
yield chunk
return
# Fallback for tool-enabled requests:
# MindIE currently drops choices when stream=True and tools are present.
# We await the full generation and yield chunks to simulate streaming.
result = await self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs)
for gen in result.generations:
msg = gen.message
content = msg.content
standard_tool_calls = getattr(msg, "tool_calls", [])
# Yield text in chunks to allow downstream UI/Markdown parsers to render smoothly
if isinstance(content, str) and content:
chunk_size = 15
for i in range(0, len(content), chunk_size):
chunk_text = content[i : i + chunk_size]
chunk_msg = AIMessageChunk(content=chunk_text, id=msg.id, response_metadata=msg.response_metadata if i == 0 else {})
yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info if i == 0 else None)
if standard_tool_calls:
yield ChatGenerationChunk(message=AIMessageChunk(content="", id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", [])))
else:
chunk_msg = AIMessageChunk(content=content, id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", []))
yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info)

View File

@ -0,0 +1,13 @@
"""DeerFlow application persistence layer (SQLAlchemy 2.0 async ORM).
This module manages DeerFlow's own application data -- runs metadata,
thread ownership, cron jobs, users. It is completely separate from
LangGraph's checkpointer, which manages graph execution state.
Usage:
from deerflow.persistence import init_engine, close_engine, get_session_factory
"""
from deerflow.persistence.engine import close_engine, get_engine, get_session_factory, init_engine
__all__ = ["close_engine", "get_engine", "get_session_factory", "init_engine"]

Some files were not shown because too many files have changed in this diff Show More