refactor(routers): reorganize routers with new langgraph/ subdirectory

Restructure app/gateway/routers/:
- Add langgraph/ subdirectory for LangGraph-related endpoints:
  - threads.py - thread management
  - runs.py - run execution and streaming
  - feedback.py - feedback endpoints
  - suggestions.py - follow-up suggestions

Remove old standalone routers:
- threads.py → langgraph/threads.py
- thread_runs.py → langgraph/runs.py
- runs.py (stateless) → langgraph/runs.py
- feedback.py → langgraph/feedback.py

Update existing routers:
- memory.py, uploads.py, artifacts.py, suggestions.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-22 11:28:26 +08:00
parent 9d0a42c1fb
commit 5f2f1941e9
13 changed files with 1332 additions and 1383 deletions

View File

@ -1,3 +1,3 @@
from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads
from . import artifacts, mcp, models, skills, suggestions, uploads
__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"]
__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"]

View File

@ -7,7 +7,6 @@ from urllib.parse import quote
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import FileResponse, PlainTextResponse, Response
from app.gateway.authz import require_permission
from app.gateway.path_utils import resolve_thread_virtual_path
logger = logging.getLogger(__name__)
@ -82,7 +81,6 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
summary="Get Artifact File",
description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.",
)
@require_permission("threads", "read", owner_check=True)
async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response:
"""Get an artifact file by its path.

View File

@ -0,0 +1,6 @@
from .feedback import router as feedback_router
from .runs import router as runs_router
from .suggestions import router as suggestion_router
from .threads import router as threads_router
__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"]

View File

@ -1,8 +1,4 @@
"""Feedback endpoints — create, list, stats, delete.
Allows users to submit thumbs-up/down feedback on runs,
optionally scoped to a specific message.
"""
"""LangGraph-compatible run feedback endpoints."""
from __future__ import annotations
@ -12,16 +8,12 @@ from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store
from app.gateway.dependencies import get_feedback_repository, get_run_repository
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
from app.plugins.auth.security.dependencies import get_current_user_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["feedback"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
router = APIRouter(tags=["feedback"])
class FeedbackCreateRequest(BaseModel):
@ -30,16 +22,11 @@ class FeedbackCreateRequest(BaseModel):
message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message")
class FeedbackUpsertRequest(BaseModel):
rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)")
comment: str | None = Field(default=None, description="Optional text feedback")
class FeedbackResponse(BaseModel):
feedback_id: str
run_id: str
thread_id: str
user_id: str | None = None
owner_id: str | None = None
message_id: str | None = None
rating: int
comment: str | None = None
@ -53,85 +40,36 @@ class FeedbackStatsResponse(BaseModel):
negative: int = 0
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def upsert_feedback(
thread_id: str,
run_id: str,
body: FeedbackUpsertRequest,
request: Request,
) -> dict[str, Any]:
"""Create or update feedback for a run (idempotent)."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
run_store = get_run_store(request)
run = await run_store.get(run_id)
async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None:
run_store = get_run_repository(request)
if resolve_request_user_id(request) is None:
run = await run_store.get(run_id, user_id=None)
else:
with bind_request_actor_context(request):
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
return await feedback_repo.upsert(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
comment=body.comment,
)
async def _get_current_user(request: Request) -> str | None:
"""Extract current user id from auth dependencies when available."""
return await get_current_user_id(request)
@router.delete("/{thread_id}/runs/{run_id}/feedback")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_run_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete the current user's feedback for a run."""
user_id = await get_current_user(request)
feedback_repo = get_feedback_repo(request)
deleted = await feedback_repo.delete_by_run(
thread_id=thread_id,
run_id=run_id,
user_id=user_id,
)
if not deleted:
raise HTTPException(status_code=404, detail="No feedback found for this run")
return {"success": True}
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def create_feedback(
async def _create_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Submit feedback (thumbs-up/down) for a run."""
if body.rating not in (1, -1):
raise HTTPException(status_code=400, detail="rating must be +1 or -1")
user_id = await get_current_user(request)
# Validate run exists and belongs to thread
run_store = get_run_store(request)
run = await run_store.get(run_id)
if run is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if run.get("thread_id") != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}")
feedback_repo = get_feedback_repo(request)
await _validate_run_scope(thread_id, run_id, request)
user_id = await _get_current_user(request)
feedback_repo = get_feedback_repository(request)
return await feedback_repo.create(
run_id=run_id,
thread_id=thread_id,
@ -142,41 +80,94 @@ async def create_feedback(
)
@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
async def upsert_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Create or replace the run-level feedback record."""
feedback_repo = get_feedback_repository(request)
user_id = await _get_current_user(request)
if user_id is not None:
return await feedback_repo.upsert(
run_id=run_id,
thread_id=thread_id,
rating=body.rating,
user_id=user_id,
comment=body.comment,
)
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
for item in existing:
feedback_id = item.get("feedback_id")
if isinstance(feedback_id, str):
await feedback_repo.delete(feedback_id)
return await _create_feedback(thread_id, run_id, body, request)
@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse)
async def create_feedback(
thread_id: str,
run_id: str,
body: FeedbackCreateRequest,
request: Request,
) -> dict[str, Any]:
"""Submit feedback for a run."""
return await _create_feedback(thread_id, run_id, body, request)
@router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse])
@require_permission("threads", "read", owner_check=True)
async def list_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> list[dict[str, Any]]:
"""List all feedback for a run."""
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(thread_id, run_id)
feedback_repo = get_feedback_repository(request)
user_id = await _get_current_user(request)
return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id)
@router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse)
@require_permission("threads", "read", owner_check=True)
async def feedback_stats(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, Any]:
"""Get aggregated feedback stats (positive/negative counts) for a run."""
feedback_repo = get_feedback_repo(request)
"""Get aggregated feedback stats for a run."""
feedback_repo = get_feedback_repository(request)
return await feedback_repo.aggregate_by_run(thread_id, run_id)
@router.delete("/{thread_id}/runs/{run_id}/feedback")
async def delete_run_feedback(
thread_id: str,
run_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete all feedback records for a run."""
feedback_repo = get_feedback_repository(request)
user_id = await _get_current_user(request)
if user_id is not None:
return {"success": await feedback_repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)}
existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None)
for item in existing:
feedback_id = item.get("feedback_id")
if isinstance(feedback_id, str):
await feedback_repo.delete(feedback_id)
return {"success": True}
@router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_feedback(
thread_id: str,
run_id: str,
feedback_id: str,
request: Request,
) -> dict[str, bool]:
"""Delete a feedback record."""
feedback_repo = get_feedback_repo(request)
# Verify feedback belongs to the specified thread/run before deleting
"""Delete a single feedback record."""
feedback_repo = get_feedback_repository(request)
existing = await feedback_repo.get(feedback_id)
if existing is None:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")

View File

@ -0,0 +1,501 @@
"""LangGraph-compatible runs endpoints backed by RunsFacade."""
from __future__ import annotations
import json
from collections.abc import AsyncIterator
from typing import Literal
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.plugins.auth.security.actor_context import bind_request_actor_context
from app.gateway.services.runs.facade_factory import build_runs_facade_from_request
from app.gateway.services.runs.input import (
AdaptedRunRequest,
RunSpecBuilder,
UnsupportedRunFeatureError,
adapt_create_run_request,
adapt_create_stream_request,
adapt_create_wait_request,
adapt_join_stream_request,
adapt_join_wait_request,
)
from deerflow.runtime.runs.types import RunRecord, RunSpec
from deerflow.runtime.stream_bridge import JSONValue, StreamEvent
router = APIRouter(tags=["runs"])
class RunCreateRequest(BaseModel):
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run")
input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command")
metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata")
config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides")
context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
webhook: str | None = Field(default=None, description="Completion callback URL")
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object")
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
class RunResponse(BaseModel):
run_id: str
thread_id: str
assistant_id: str | None = None
status: str
metadata: dict[str, JSONValue] = Field(default_factory=dict)
multitask_strategy: str = "reject"
created_at: str = ""
updated_at: str = ""
class RunDeleteResponse(BaseModel):
deleted: bool
class RunMessageResponse(BaseModel):
run_id: str
content: JSONValue
metadata: dict[str, JSONValue] = Field(default_factory=dict)
created_at: str
seq: int
class RunMessagesResponse(BaseModel):
data: list[RunMessageResponse]
hasMore: bool = False
def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str:
"""Format a single SSE frame."""
payload = json.dumps(data, default=str, ensure_ascii=False)
parts = [f"event: {event}", f"data: {payload}"]
if event_id:
parts.append(f"id: {event_id}")
parts.append("")
parts.append("")
return "\n".join(parts)
def _record_to_response(record: RunRecord) -> RunResponse:
return RunResponse(
run_id=record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status,
metadata=record.metadata,
multitask_strategy=record.multitask_strategy,
created_at=record.created_at,
updated_at=record.updated_at,
)
def _trim_paginated_rows(
rows: list[dict],
*,
limit: int,
after_seq: int | None,
) -> tuple[list[dict], bool]:
has_more = len(rows) > limit
if not has_more:
return rows, False
if after_seq is not None:
return rows[:limit], True
return rows[-limit:], True
def _event_to_run_message(event: dict) -> RunMessageResponse:
return RunMessageResponse(
run_id=str(event["run_id"]),
content=event.get("content"),
metadata=dict(event.get("metadata") or {}),
created_at=str(event.get("created_at") or ""),
seq=int(event["seq"]),
)
async def _sse_consumer(
stream: AsyncIterator[StreamEvent],
request: Request,
*,
cancel_on_disconnect: bool,
cancel_run,
run_id: str,
) -> AsyncIterator[str]:
try:
async for event in stream:
if await request.is_disconnected():
break
if event.event == "__heartbeat__":
yield ": heartbeat\n\n"
continue
if event.event == "__end__":
yield format_sse("end", None, event_id=event.id or None)
return
if event.event == "__cancelled__":
yield format_sse("cancel", None, event_id=event.id or None)
return
yield format_sse(event.event, event.data, event_id=event.id or None)
finally:
if cancel_on_disconnect:
await cancel_run(run_id)
def _get_run_event_store(request: Request):
event_store = getattr(request.app.state, "run_event_store", None)
if event_store is None:
raise HTTPException(status_code=503, detail="Run event store not available")
return event_store
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
async def list_runs(
thread_id: str,
request: Request,
limit: int = 100,
offset: int = 0,
status: str | None = None,
) -> list[RunResponse]:
# Accepted for API compatibility; field projection is not implemented yet.
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
records = await facade.list_runs(thread_id)
if status is not None:
records = [record for record in records if record.status == status]
records = records[offset : offset + limit]
return [_record_to_response(record) for record in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return _record_to_response(record)
@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse)
async def run_messages(
thread_id: str,
run_id: str,
request: Request,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> RunMessagesResponse:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
event_store = _get_run_event_store(request)
with bind_request_actor_context(request):
rows = await event_store.list_messages_by_run(
thread_id,
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
page, has_more = _trim_paginated_rows(rows, limit=limit, after_seq=after_seq)
return RunMessagesResponse(data=[_event_to_run_message(row) for row in page], hasMore=has_more)
def _build_spec(
*,
adapted: AdaptedRunRequest,
) -> RunSpec:
try:
return RunSpecBuilder().build(adapted)
except UnsupportedRunFeatureError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
@router.post("/{thread_id}/runs", response_model=RunResponse)
async def create_run(
thread_id: str,
body: RunCreateRequest,
request: Request,
) -> Response:
adapted = adapt_create_run_request(
thread_id=thread_id,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.create_background(spec)
return Response(
content=_record_to_response(record).model_dump_json(),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
)
@router.post("/{thread_id}/runs/stream")
async def stream_run(
thread_id: str,
body: RunCreateRequest,
request: Request,
) -> StreamingResponse:
adapted = adapt_create_stream_request(
thread_id=thread_id,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, stream = await facade.create_and_stream(spec)
return StreamingResponse(
_sse_consumer(
stream,
request,
cancel_on_disconnect=spec.on_disconnect == "cancel",
cancel_run=facade.cancel,
run_id=record.run_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/{thread_id}/runs/wait")
async def wait_run(
thread_id: str,
body: RunCreateRequest,
request: Request,
) -> Response:
adapted = adapt_create_wait_request(
thread_id=thread_id,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, result = await facade.create_and_wait(spec)
return Response(
content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"},
)
@router.post("/runs", response_model=RunResponse)
async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response:
adapted = adapt_create_run_request(
thread_id=None,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.create_background(spec)
return Response(
content=_record_to_response(record).model_dump_json(),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
)
@router.post("/runs/stream")
async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse:
adapted = adapt_create_stream_request(
thread_id=None,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, stream = await facade.create_and_stream(spec)
return StreamingResponse(
_sse_consumer(
stream,
request,
cancel_on_disconnect=spec.on_disconnect == "cancel",
cancel_run=facade.cancel,
run_id=record.run_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}",
},
)
@router.post("/runs/wait")
async def wait_stateless_run(body: RunCreateRequest, request: Request) -> Response:
adapted = adapt_create_wait_request(
thread_id=None,
body=body.model_dump(),
headers=dict(request.headers),
query=dict(request.query_params),
)
spec = _build_spec(adapted=adapted)
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record, result = await facade.create_and_wait(spec)
return Response(
content=json.dumps(result, default=str, ensure_ascii=False),
media_type="application/json",
headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"},
)
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
async def stream_existing_run(
thread_id: str,
run_id: str,
request: Request,
action: Literal["interrupt", "rollback"] | None = None,
wait: bool = False,
cancel_on_disconnect: bool = False,
stream_mode: str | None = None,
) -> StreamingResponse | Response:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
if action is not None:
with bind_request_actor_context(request):
cancelled = await facade.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
if wait:
with bind_request_actor_context(request):
await facade.join_wait(run_id)
return Response(status_code=204)
adapted = adapt_join_stream_request(
thread_id=thread_id,
run_id=run_id,
headers=dict(request.headers),
query=dict(request.query_params),
)
with bind_request_actor_context(request):
stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id)
return StreamingResponse(
_sse_consumer(
stream,
request,
cancel_on_disconnect=cancel_on_disconnect,
cancel_run=facade.cancel,
run_id=run_id,
),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.get("/{thread_id}/runs/{run_id}/join")
async def join_existing_run(
thread_id: str,
run_id: str,
request: Request,
cancel_on_disconnect: bool = False,
) -> JSONValue:
# Accepted for API compatibility; current join_wait path does not change
# behavior based on client disconnect.
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
adapted = adapt_join_wait_request(
thread_id=thread_id,
run_id=run_id,
headers=dict(request.headers),
query=dict(request.query_params),
)
with bind_request_actor_context(request):
return await facade.join_wait(run_id, last_event_id=adapted.last_event_id)
@router.post("/{thread_id}/runs/{run_id}/cancel")
async def cancel_existing_run(
thread_id: str,
run_id: str,
request: Request,
wait: bool = False,
action: Literal["interrupt", "rollback"] = "interrupt",
) -> JSONValue:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
with bind_request_actor_context(request):
cancelled = await facade.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable")
if wait:
with bind_request_actor_context(request):
return await facade.join_wait(run_id)
return {}
@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse)
async def delete_run(
thread_id: str,
run_id: str,
request: Request,
) -> RunDeleteResponse:
facade = build_runs_facade_from_request(request)
with bind_request_actor_context(request):
record = await facade.get_run(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
with bind_request_actor_context(request):
deleted = await facade.delete_run(run_id)
return RunDeleteResponse(deleted=deleted)

View File

@ -0,0 +1,132 @@
import json
import logging
from fastapi import APIRouter
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api", tags=["suggestions"])
class SuggestionMessage(BaseModel):
role: str = Field(..., description="Message role: user|assistant")
content: str = Field(..., description="Message content as plain text")
class SuggestionsRequest(BaseModel):
messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages")
n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate")
model_name: str | None = Field(default=None, description="Optional model override")
class SuggestionsResponse(BaseModel):
suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions")
def _strip_markdown_code_fence(text: str) -> str:
stripped = text.strip()
if not stripped.startswith("```"):
return stripped
lines = stripped.splitlines()
if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"):
return "\n".join(lines[1:-1]).strip()
return stripped
def _parse_json_string_list(text: str) -> list[str] | None:
candidate = _strip_markdown_code_fence(text)
start = candidate.find("[")
end = candidate.rfind("]")
if start == -1 or end == -1 or end <= start:
return None
candidate = candidate[start : end + 1]
try:
data = json.loads(candidate)
except Exception:
return None
if not isinstance(data, list):
return None
out: list[str] = []
for item in data:
if not isinstance(item, str):
continue
s = item.strip()
if not s:
continue
out.append(s)
return out
def _extract_response_text(content: object) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for block in content:
if isinstance(block, str):
parts.append(block)
elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}:
text = block.get("text")
if isinstance(text, str):
parts.append(text)
return "\n".join(parts) if parts else ""
if content is None:
return ""
return str(content)
def _format_conversation(messages: list[SuggestionMessage]) -> str:
parts: list[str] = []
for m in messages:
role = m.role.strip().lower()
if role in ("user", "human"):
parts.append(f"User: {m.content.strip()}")
elif role in ("assistant", "ai"):
parts.append(f"Assistant: {m.content.strip()}")
else:
parts.append(f"{m.role}: {m.content.strip()}")
return "\n".join(parts).strip()
@router.post(
"/threads/{thread_id}/suggestions",
response_model=SuggestionsResponse,
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse:
if not request.messages:
return SuggestionsResponse(suggestions=[])
n = request.n
conversation = _format_conversation(request.messages)
if not conversation:
return SuggestionsResponse(suggestions=[])
system_instruction = (
"You are generating follow-up questions to help the user continue the conversation.\n"
f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n"
"Requirements:\n"
"- Questions must be relevant to the preceding conversation.\n"
"- Questions must be written in the same language as the user.\n"
"- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n"
"- Do NOT include numbering, markdown, or any extra text.\n"
"- Output MUST be a JSON array of strings only.\n"
)
user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions"
try:
model = create_chat_model(name=request.model_name, thinking_enabled=False)
response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)])
raw = _extract_response_text(response.content)
suggestions = _parse_json_string_list(raw) or []
cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()]
cleaned = cleaned[:n]
return SuggestionsResponse(suggestions=cleaned)
except Exception as exc:
logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc)
return SuggestionsResponse(suggestions=[])

View File

@ -0,0 +1,455 @@
"""Thread management endpoints.
Provides CRUD operations for threads and checkpoint state management.
"""
from __future__ import annotations
import logging
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage
from app.infra.storage import ThreadMetaStorage
from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(tags=["threads"])
# ---------------------------------------------------------------------------
# Request / Response Models
# ---------------------------------------------------------------------------
class ThreadCreateRequest(BaseModel):
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
class ThreadSearchRequest(BaseModel):
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
offset: int = Field(default=0, ge=0, description="Pagination offset")
status: str | None = Field(default=None, description="Filter by thread status")
user_id: str | None = Field(default=None, description="Filter by user ID")
assistant_id: str | None = Field(default=None, description="Filter by assistant ID")
class ThreadResponse(BaseModel):
thread_id: str = Field(description="Unique thread identifier")
status: str = Field(default="idle", description="Thread status")
created_at: str = Field(default="", description="ISO timestamp")
updated_at: str = Field(default="", description="ISO timestamp")
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
values: dict[str, Any] = Field(default_factory=dict, description="Current state values")
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
class ThreadDeleteResponse(BaseModel):
success: bool
message: str
class ThreadStateUpdateRequest(BaseModel):
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
as_node: str | None = Field(default=None, description="Node identity for the update")
class ThreadStateResponse(BaseModel):
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
next: list[str] = Field(default_factory=list, description="Next nodes to execute")
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
class ThreadHistoryRequest(BaseModel):
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)")
class HistoryEntry(BaseModel):
checkpoint_id: str
parent_checkpoint_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
values: dict[str, Any] = Field(default_factory=dict)
created_at: str | None = None
next: list[str] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def sanitize_log_param(value: str) -> str:
"""Strip control characters to prevent log injection."""
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse:
"""Delete local filesystem data for a thread."""
path_manager = paths or get_paths()
try:
path_manager.delete_thread_dir(thread_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
async def _thread_or_run_exists(
*,
request: Request,
thread_id: str,
thread_meta_storage: ThreadMetaStorage,
run_repo,
) -> bool:
request_user_id = resolve_request_user_id(request)
if request_user_id is None:
thread = await thread_meta_storage.get_thread(thread_id, user_id=None)
if thread is not None:
return True
runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None)
return bool(runs)
with bind_request_actor_context(request):
thread = await thread_meta_storage.get_thread(thread_id)
if thread is not None:
return True
runs = await run_repo.list_by_thread(thread_id, limit=1)
return bool(runs)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("", response_model=ThreadResponse)
async def create_thread(
body: ThreadCreateRequest,
request: Request,
thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadResponse:
"""Create a new thread."""
thread_id = body.thread_id or str(uuid.uuid4())
request_user_id = resolve_request_user_id(request)
if request_user_id is None:
existing = await thread_meta_storage.get_thread(thread_id, user_id=None)
else:
with bind_request_actor_context(request):
existing = await thread_meta_storage.get_thread(thread_id)
if existing is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing.status,
created_at=existing.created_time.isoformat() if existing.created_time else "",
updated_at=existing.updated_time.isoformat() if existing.updated_time else "",
metadata=existing.metadata,
)
try:
if request_user_id is None:
created = await thread_meta_storage.ensure_thread(
thread_id=thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
user_id=None,
)
else:
with bind_request_actor_context(request):
created = await thread_meta_storage.ensure_thread(
thread_id=thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
except Exception:
logger.exception("Failed to create thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", sanitize_log_param(thread_id))
return ThreadResponse(
thread_id=thread_id,
status=created.status,
created_at=created.created_time.isoformat() if created.created_time else "",
updated_at=created.updated_time.isoformat() if created.updated_time else "",
metadata=created.metadata,
)
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(
body: ThreadSearchRequest,
request: Request,
thread_meta_storage: CurrentThreadMetaStorage,
) -> list[ThreadResponse]:
"""Search threads with filters."""
try:
request_user_id = resolve_request_user_id(request)
if request_user_id is None:
threads = await thread_meta_storage.search_threads(
metadata=body.metadata or None,
status=body.status,
user_id=body.user_id,
assistant_id=body.assistant_id,
limit=body.limit,
offset=body.offset,
)
else:
with bind_request_actor_context(request):
threads = await thread_meta_storage.search_threads(
metadata=body.metadata or None,
status=body.status,
assistant_id=body.assistant_id,
limit=body.limit,
offset=body.offset,
)
except Exception:
logger.exception("Failed to search threads")
raise HTTPException(status_code=500, detail="Failed to search threads")
return [
ThreadResponse(
thread_id=t.thread_id,
status=t.status,
created_at=t.created_time.isoformat() if t.created_time else "",
updated_at=t.updated_time.isoformat() if t.updated_time else "",
metadata=t.metadata,
values={"title": t.display_name} if t.display_name else {},
interrupts={},
)
for t in threads
]
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
async def delete_thread(
thread_id: str,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadDeleteResponse:
"""Delete a thread and all associated data."""
response = _delete_thread_data(thread_id)
# Remove checkpoints (best-effort)
try:
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s", sanitize_log_param(thread_id))
# Remove thread_meta (best-effort)
try:
await thread_meta_storage.delete_thread(thread_id)
except Exception:
logger.debug("Could not delete thread_meta for %s", sanitize_log_param(thread_id))
return response
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
async def get_thread_state(
thread_id: str,
request: Request,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
run_repo: CurrentRunRepository,
) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread."""
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
if await _thread_or_run_exists(
request=request,
thread_id=thread_id,
thread_meta_storage=thread_meta_storage,
run_repo=run_repo,
):
return ThreadStateResponse()
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
channel_values = checkpoint.get("channel_values", {})
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
parent_config = getattr(checkpoint_tuple, "parent_config", None)
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=next_nodes,
tasks=tasks,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
metadata=metadata,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
async def update_thread_state(
thread_id: str,
body: ThreadStateUpdateRequest,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
) -> ThreadStateResponse:
"""Update thread state (human-in-the-loop or title rename)."""
read_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
if body.checkpoint_id:
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
if body.values:
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
metadata["step"] = metadata.get("step", 0) + 1
metadata["writes"] = {body.as_node: body.values}
write_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title to thread_meta
if body.values and "title" in body.values:
new_title = body.values["title"]
if new_title:
try:
await thread_meta_storage.sync_thread_title(
thread_id=thread_id,
title=new_title,
)
except Exception:
logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id))
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
async def get_thread_history(
thread_id: str,
body: ThreadHistoryRequest,
request: Request,
checkpointer: CurrentCheckpointer,
thread_meta_storage: CurrentThreadMetaStorage,
run_repo: CurrentRunRepository,
) -> list[HistoryEntry]:
"""Get checkpoint history for a thread."""
config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
if body.before:
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
is_first = True
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {}) or {}
parent_config = getattr(checkpoint_tuple, "parent_config", None)
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None
channel_values = checkpoint.get("channel_values", {})
values: dict[str, Any] = {}
if title := channel_values.get("title"):
values["title"] = title
if is_first and (messages := channel_values.get("messages")):
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
is_first = False
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=metadata,
values=values,
created_at=str(metadata.get("created_at", "")),
next=next_nodes,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread history")
if not entries and await _thread_or_run_exists(
request=request,
thread_id=thread_id,
thread_meta_storage=thread_meta_storage,
run_repo=run_repo,
):
return []
return entries

View File

@ -1,8 +1,9 @@
"""Memory API router for retrieving and managing global memory data."""
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field
from app.plugins.auth.security.actor_context import bind_request_actor_context
from deerflow.agents.memory.updater import (
clear_memory_data,
create_memory_fact,
@ -13,7 +14,7 @@ from deerflow.agents.memory.updater import (
update_memory_fact,
)
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import get_effective_user_id
from deerflow.runtime.actor_context import get_effective_user_id
router = APIRouter(prefix="/api", tags=["memory"])
@ -114,7 +115,7 @@ class MemoryStatusResponse(BaseModel):
summary="Get Memory Data",
description="Retrieve the current global memory data including user context, history, and facts.",
)
async def get_memory() -> MemoryResponse:
async def get_memory(request: Request) -> MemoryResponse:
"""Get the current global memory data.
Returns:
@ -148,8 +149,9 @@ async def get_memory() -> MemoryResponse:
}
```
"""
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
with bind_request_actor_context(request):
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@router.post(
@ -159,7 +161,7 @@ async def get_memory() -> MemoryResponse:
summary="Reload Memory Data",
description="Reload memory data from the storage file, refreshing the in-memory cache.",
)
async def reload_memory() -> MemoryResponse:
async def reload_memory(request: Request) -> MemoryResponse:
"""Reload memory data from file.
This forces a reload of the memory data from the storage file,
@ -168,8 +170,9 @@ async def reload_memory() -> MemoryResponse:
Returns:
The reloaded memory data.
"""
memory_data = reload_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
with bind_request_actor_context(request):
memory_data = reload_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@router.delete(
@ -179,14 +182,15 @@ async def reload_memory() -> MemoryResponse:
summary="Clear All Memory Data",
description="Delete all saved memory data and reset the memory structure to an empty state.",
)
async def clear_memory() -> MemoryResponse:
async def clear_memory(request: Request) -> MemoryResponse:
"""Clear all persisted memory data."""
try:
memory_data = clear_memory_data(user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
with bind_request_actor_context(request):
try:
memory_data = clear_memory_data(user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.post(
@ -196,21 +200,22 @@ async def clear_memory() -> MemoryResponse:
summary="Create Memory Fact",
description="Create a single saved memory fact manually.",
)
async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse:
async def create_memory_fact_endpoint(request: Request, payload: FactCreateRequest) -> MemoryResponse:
"""Create a single fact manually."""
try:
memory_data = create_memory_fact(
content=request.content,
category=request.category,
confidence=request.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
with bind_request_actor_context(request):
try:
memory_data = create_memory_fact(
content=payload.content,
category=payload.category,
confidence=payload.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.delete(
@ -220,16 +225,17 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo
summary="Delete Memory Fact",
description="Delete a single saved memory fact by its fact id.",
)
async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryResponse:
"""Delete a single fact from memory by fact id."""
try:
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
with bind_request_actor_context(request):
try:
memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id())
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.patch(
@ -239,24 +245,25 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse:
summary="Patch Memory Fact",
description="Partially update a single saved memory fact by its fact id while preserving omitted fields.",
)
async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse:
async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: FactPatchRequest) -> MemoryResponse:
"""Partially update a single fact manually."""
try:
memory_data = update_memory_fact(
fact_id=fact_id,
content=request.content,
category=request.category,
confidence=request.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
with bind_request_actor_context(request):
try:
memory_data = update_memory_fact(
fact_id=fact_id,
content=payload.content,
category=payload.category,
confidence=payload.confidence,
user_id=get_effective_user_id(),
)
except ValueError as exc:
raise _map_memory_fact_value_error(exc) from exc
except KeyError as exc:
raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.get(
@ -266,10 +273,11 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -
summary="Export Memory Data",
description="Export the current global memory data as JSON for backup or transfer.",
)
async def export_memory() -> MemoryResponse:
async def export_memory(request: Request) -> MemoryResponse:
"""Export the current memory data."""
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
with bind_request_actor_context(request):
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryResponse(**memory_data)
@router.post(
@ -279,14 +287,15 @@ async def export_memory() -> MemoryResponse:
summary="Import Memory Data",
description="Import and overwrite the current global memory data from a JSON payload.",
)
async def import_memory(request: MemoryResponse) -> MemoryResponse:
async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResponse:
"""Import and persist memory data."""
try:
memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
with bind_request_actor_context(request):
try:
memory_data = import_memory_data(payload.model_dump(), user_id=get_effective_user_id())
except OSError as exc:
raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc
return MemoryResponse(**memory_data)
return MemoryResponse(**memory_data)
@router.get(
@ -333,24 +342,25 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
summary="Get Memory Status",
description="Retrieve both memory configuration and current data in a single request.",
)
async def get_memory_status() -> MemoryStatusResponse:
async def get_memory_status(request: Request) -> MemoryStatusResponse:
"""Get the memory system status including configuration and data.
Returns:
Combined memory configuration and current data.
"""
config = get_memory_config()
memory_data = get_memory_data(user_id=get_effective_user_id())
with bind_request_actor_context(request):
config = get_memory_config()
memory_data = get_memory_data(user_id=get_effective_user_id())
return MemoryStatusResponse(
config=MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
),
data=MemoryResponse(**memory_data),
)
return MemoryStatusResponse(
config=MemoryConfigResponse(
enabled=config.enabled,
storage_path=config.storage_path,
debounce_seconds=config.debounce_seconds,
max_facts=config.max_facts,
fact_confidence_threshold=config.fact_confidence_threshold,
injection_enabled=config.injection_enabled,
max_injection_tokens=config.max_injection_tokens,
),
data=MemoryResponse(**memory_data),
)

View File

@ -1,143 +0,0 @@
"""Stateless runs endpoints -- stream and wait without a pre-existing thread.
These endpoints auto-create a temporary thread when no ``thread_id`` is
supplied in the request body. When a ``thread_id`` **is** provided, it
is reused so that conversation history is preserved across calls.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import StreamingResponse
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.thread_runs import RunCreateRequest
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/runs", tags=["runs"])
def _resolve_thread_id(body: RunCreateRequest) -> str:
"""Return the thread_id from the request body, or generate a new one."""
thread_id = (body.config or {}).get("configurable", {}).get("thread_id")
if thread_id:
return str(thread_id)
return str(uuid.uuid4())
@router.post("/stream")
async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/wait", response_model=dict)
async def stateless_wait(body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until completion.
If ``config.configurable.thread_id`` is provided, the run is created
on the given thread so that conversation history is preserved.
Otherwise a new temporary thread is created.
"""
thread_id = _resolve_thread_id(body)
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
# ---------------------------------------------------------------------------
# Run-scoped read endpoints
# ---------------------------------------------------------------------------
async def _resolve_run(run_id: str, request: Request) -> dict:
"""Fetch run by run_id with user ownership check. Raises 404 if not found."""
run_store = get_run_store(request)
record = await run_store.get(run_id) # user_id=AUTO filters by contextvar
if record is None:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return record
@router.get("/{run_id}/messages")
@require_permission("runs", "read")
async def run_messages(
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a run (cursor-based).
Pagination:
- after_seq: messages with seq > after_seq (forward)
- before_seq: messages with seq < before_seq (backward)
- neither: latest messages
Response: { data: [...], has_more: bool }
"""
run = await _resolve_run(run_id, request)
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
run["thread_id"],
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{run_id}/feedback")
@require_permission("runs", "read")
async def run_feedback(run_id: str, request: Request) -> list[dict]:
"""Return all feedback for a run."""
run = await _resolve_run(run_id, request)
feedback_repo = get_feedback_repo(request)
return await feedback_repo.list_by_run(run["thread_id"], run_id)

View File

@ -5,7 +5,6 @@ from fastapi import APIRouter, Request
from langchain_core.messages import HumanMessage, SystemMessage
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from deerflow.models import create_chat_model
logger = logging.getLogger(__name__)
@ -99,7 +98,6 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str:
summary="Generate Follow-up Questions",
description="Generate short follow-up questions a user might ask next, based on recent conversation context.",
)
@require_permission("threads", "read", owner_check=True)
async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse:
if not body.messages:
return SuggestionsResponse(suggestions=[])

View File

@ -1,377 +0,0 @@
"""Runs endpoints — create, stream, wait, cancel.
Implements the LangGraph Platform runs API on top of
:class:`deerflow.agents.runs.RunManager` and
:class:`deerflow.agents.stream_bridge.StreamBridge`.
SSE format is aligned with the LangGraph Platform protocol so that
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
works without modification.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import Response, StreamingResponse
from pydantic import BaseModel, Field
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.services import sse_consumer, start_run
from deerflow.runtime import RunRecord, serialize_channel_values
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["runs"])
# ---------------------------------------------------------------------------
# Request / response models
# ---------------------------------------------------------------------------
class RunCreateRequest(BaseModel):
assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
webhook: str | None = Field(default=None, description="Completion callback URL")
checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
after_seconds: float | None = Field(default=None, description="Delayed execution")
if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
class RunResponse(BaseModel):
run_id: str
thread_id: str
assistant_id: str | None = None
status: str
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)
multitask_strategy: str = "reject"
created_at: str = ""
updated_at: str = ""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _record_to_response(record: RunRecord) -> RunResponse:
return RunResponse(
run_id=record.run_id,
thread_id=record.thread_id,
assistant_id=record.assistant_id,
status=record.status.value,
metadata=record.metadata,
kwargs=record.kwargs,
multitask_strategy=record.multitask_strategy,
created_at=record.created_at,
updated_at=record.updated_at,
)
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
"""Create a background run (returns immediately)."""
record = await start_run(body, thread_id, request)
return _record_to_response(record)
@router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
"""Create a run and stream events via SSE.
The response includes a ``Content-Location`` header with the run's
resource URL, matching the LangGraph Platform protocol. The
``useStream`` React hook uses this to extract run metadata.
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = await start_run(body, thread_id, request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
# LangGraph Platform includes run metadata in this header.
# The SDK uses a greedy regex to extract the run id from this path,
# so it must point at the canonical run resource without extra suffixes.
"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
},
)
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
"""Create a run and block until it completes, returning the final state."""
record = await start_run(body, thread_id, request)
if record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
if checkpoint_tuple is not None:
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
channel_values = checkpoint.get("channel_values", {})
return serialize_channel_values(channel_values)
except Exception:
logger.exception("Failed to fetch final state for run %s", record.run_id)
return {"status": record.status.value, "error": record.error}
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
"""List all runs for a thread."""
run_mgr = get_run_manager(request)
records = await run_mgr.list_by_thread(thread_id)
return [_record_to_response(r) for r in records]
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
"""Get details of a specific run."""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return _record_to_response(record)
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
thread_id: str,
run_id: str,
request: Request,
wait: bool = Query(default=False, description="Block until run completes after cancel"),
action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
"""Cancel a running or pending run.
- action=interrupt: Stop execution, keep current checkpoint (can be resumed)
- action=rollback: Stop execution, revert to pre-run checkpoint state
- wait=true: Block until the run fully stops, return 204
- wait=false: Return immediately with 202
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
cancelled = await run_mgr.cancel(run_id, action=action)
if not cancelled:
raise HTTPException(
status_code=409,
detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
)
if wait and record.task is not None:
try:
await record.task
except asyncio.CancelledError:
pass
return Response(status_code=204)
return Response(status_code=202)
@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
"""Join an existing run's SSE stream."""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
thread_id: str,
run_id: str,
request: Request,
action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
"""Join an existing run's SSE stream (GET), or cancel-then-stream (POST).
The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
is present the run is cancelled first; the response then streams any
remaining buffered events so the client observes a clean shutdown.
"""
run_mgr = get_run_manager(request)
record = run_mgr.get(run_id)
if record is None or record.thread_id != thread_id:
raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
# Cancel if an action was requested (stop-button / interrupt flow)
if action is not None:
cancelled = await run_mgr.cancel(run_id, action=action)
if cancelled and wait and record.task is not None:
try:
await record.task
except (asyncio.CancelledError, Exception):
pass
return Response(status_code=204)
bridge = get_stream_bridge(request)
return StreamingResponse(
sse_consumer(bridge, record, request, run_mgr),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
# ---------------------------------------------------------------------------
# Messages / Events / Token usage endpoints
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
thread_id: str,
request: Request,
limit: int = Query(default=50, le=200),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> list[dict]:
"""Return displayable messages for a thread (across all runs), with feedback attached."""
event_store = get_run_event_store(request)
messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)
# Attach feedback to the last AI message of each run
feedback_repo = get_feedback_repo(request)
user_id = await get_current_user(request)
feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)
# Find the last ai_message per run_id
last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list
for i, msg in enumerate(messages):
if msg.get("event_type") == "ai_message":
last_ai_per_run[msg["run_id"]] = i
# Attach feedback field
last_ai_indices = set(last_ai_per_run.values())
for i, msg in enumerate(messages):
if i in last_ai_indices:
run_id = msg["run_id"]
fb = feedback_map.get(run_id)
msg["feedback"] = (
{
"feedback_id": fb["feedback_id"],
"rating": fb["rating"],
"comment": fb.get("comment"),
}
if fb
else None
)
else:
msg["feedback"] = None
return messages
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(
thread_id: str,
run_id: str,
request: Request,
limit: int = Query(default=50, le=200, ge=1),
before_seq: int | None = Query(default=None),
after_seq: int | None = Query(default=None),
) -> dict:
"""Return paginated messages for a specific run.
Response: { data: [...], has_more: bool }
"""
event_store = get_run_event_store(request)
rows = await event_store.list_messages_by_run(
thread_id,
run_id,
limit=limit + 1,
before_seq=before_seq,
after_seq=after_seq,
)
has_more = len(rows) > limit
data = rows[:limit] if has_more else rows
return {"data": data, "has_more": has_more}
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
thread_id: str,
run_id: str,
request: Request,
event_types: str | None = Query(default=None),
limit: int = Query(default=500, le=2000),
) -> list[dict]:
"""Return the full event stream for a run (debug/audit)."""
event_store = get_run_event_store(request)
types = event_types.split(",") if event_types else None
return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
"""Thread-level token usage aggregation."""
run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return {"thread_id": thread_id, **agg}

View File

@ -1,621 +0,0 @@
"""Thread CRUD, state, and history endpoints.
Combines the existing thread-local filesystem cleanup with LangGraph
Platform-compatible thread management backed by the checkpointer.
Channel values returned in state responses are serialized through
:func:`deerflow.runtime.serialization.serialize_channel_values` to
ensure LangChain message objects are converted to JSON-safe dicts
matching the LangGraph Platform wire format expected by the
``useStream`` React hook.
"""
from __future__ import annotations
import logging
import re
import time
import uuid
from typing import Any
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer
from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values
from deerflow.runtime.user_context import get_effective_user_id
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"])
# Metadata keys that the server controls; clients are not allowed to set
# them. Pydantic ``@field_validator("metadata")`` strips them on every
# inbound model below so a malicious client cannot reflect a forged
# owner identity through the API surface. Defense-in-depth — the
# row-level invariant is still ``threads_meta.user_id`` populated from
# the auth contextvar; this list closes the metadata-blob echo gap.
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]:
"""Return ``metadata`` with server-controlled keys removed."""
if not metadata:
return metadata or {}
return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS}
# ---------------------------------------------------------------------------
# Response / request models
# ---------------------------------------------------------------------------
class ThreadDeleteResponse(BaseModel):
"""Response model for thread cleanup."""
success: bool
message: str
class ThreadResponse(BaseModel):
"""Response model for a single thread."""
thread_id: str = Field(description="Unique thread identifier")
status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error")
created_at: str = Field(default="", description="ISO timestamp")
updated_at: str = Field(default="", description="ISO timestamp")
metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata")
values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values")
interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts")
class ThreadCreateRequest(BaseModel):
"""Request body for creating a thread."""
thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)")
assistant_id: str | None = Field(default=None, description="Associate thread with an assistant")
metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadSearchRequest(BaseModel):
"""Request body for searching threads."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)")
limit: int = Field(default=100, ge=1, le=1000, description="Maximum results")
offset: int = Field(default=0, ge=0, description="Pagination offset")
status: str | None = Field(default=None, description="Filter by thread status")
class ThreadStateResponse(BaseModel):
"""Response model for thread state."""
values: dict[str, Any] = Field(default_factory=dict, description="Current channel values")
next: list[str] = Field(default_factory=list, description="Next tasks to execute")
metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata")
checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info")
checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID")
parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID")
created_at: str | None = Field(default=None, description="Checkpoint timestamp")
tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details")
class ThreadPatchRequest(BaseModel):
"""Request body for patching thread metadata."""
metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge")
_strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v)))
class ThreadStateUpdateRequest(BaseModel):
"""Request body for updating thread state (human-in-the-loop resume)."""
values: dict[str, Any] | None = Field(default=None, description="Channel values to merge")
checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from")
checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
as_node: str | None = Field(default=None, description="Node identity for the update")
class HistoryEntry(BaseModel):
"""Single checkpoint history entry."""
checkpoint_id: str
parent_checkpoint_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
values: dict[str, Any] = Field(default_factory=dict)
created_at: str | None = None
next: list[str] = Field(default_factory=list)
class ThreadHistoryRequest(BaseModel):
"""Request body for checkpoint history."""
limit: int = Field(default=10, ge=1, le=100, description="Maximum entries")
before: str | None = Field(default=None, description="Cursor for pagination")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread."""
path_manager = paths or get_paths()
try:
path_manager.delete_thread_dir(thread_id, user_id=user_id)
except ValueError as exc:
raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError:
# Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc:
logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
def _derive_thread_status(checkpoint_tuple) -> str:
"""Derive thread status from checkpoint metadata."""
if checkpoint_tuple is None:
return "idle"
pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or []
# Check for error in pending writes
for pw in pending_writes:
if len(pw) >= 2 and pw[1] == "__error__":
return "error"
# Check for pending next tasks (indicates interrupt)
tasks = getattr(checkpoint_tuple, "tasks", None)
if tasks:
return "interrupted"
return "idle"
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@router.delete("/{thread_id}", response_model=ThreadDeleteResponse)
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse:
"""Delete local persisted filesystem data for a thread.
Cleans DeerFlow-managed thread directories, removes checkpoint data,
and removes the thread_meta row from the configured ThreadMetaStore
(sqlite or memory).
"""
from app.gateway.deps import get_thread_store
# Clean local filesystem
response = _delete_thread_data(thread_id, user_id=get_effective_user_id())
# Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None)
if checkpointer is not None:
try:
if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id)
except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id))
# Remove thread_meta row (best-effort) — required for sqlite backend
# so the deleted thread no longer appears in /threads/search.
try:
thread_store = get_thread_store(request)
await thread_store.delete(thread_id)
except Exception:
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
return response
@router.post("", response_model=ThreadResponse)
async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse:
"""Create a new thread.
Writes a thread_meta record (so the thread appears in /threads/search)
and an empty checkpoint (so state endpoints work immediately).
Idempotent: returns the existing record when ``thread_id`` already exists.
"""
from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4())
now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present
existing_record = await thread_store.get(thread_id)
if existing_record is not None:
return ThreadResponse(
thread_id=thread_id,
status=existing_record.get("status", "idle"),
created_at=str(existing_record.get("created_at", "")),
updated_at=str(existing_record.get("updated_at", "")),
metadata=existing_record.get("metadata", {}),
)
# Write thread_meta so the thread appears in /threads/search immediately
try:
await thread_store.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
metadata=body.metadata,
)
except Exception:
logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
from langgraph.checkpoint.base import empty_checkpoint
ckpt_metadata = {
"step": -1,
"source": "input",
"writes": None,
"parents": {},
**body.metadata,
"created_at": now,
}
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception:
logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", sanitize_log_param(thread_id))
return ThreadResponse(
thread_id=thread_id,
status="idle",
created_at=str(now),
updated_at=str(now),
metadata=body.metadata,
)
@router.post("/search", response_model=list[ThreadResponse])
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
"""Search and list threads.
Delegates to the configured ThreadMetaStore implementation
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
"""
from app.gateway.deps import get_thread_store
repo = get_thread_store(request)
rows = await repo.search(
metadata=body.metadata or None,
status=body.status,
limit=body.limit,
offset=body.offset,
)
return [
ThreadResponse(
thread_id=r["thread_id"],
status=r.get("status", "idle"),
created_at=r.get("created_at", ""),
updated_at=r.get("updated_at", ""),
metadata=r.get("metadata", {}),
values={"title": r["display_name"]} if r.get("display_name") else {},
interrupts={},
)
for r in rows
]
@router.patch("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record."""
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
record = await thread_store.get(thread_id)
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
try:
await thread_store.update_metadata(thread_id, body.metadata)
except Exception:
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread")
# Re-read to get the merged metadata + refreshed updated_at
record = await thread_store.get(thread_id) or record
return ThreadResponse(
thread_id=thread_id,
status=record.get("status", "idle"),
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
)
@router.get("/{thread_id}", response_model=ThreadResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
"""Get thread info.
Reads metadata from the ThreadMetaStore and derives the accurate
execution status from the checkpointer. Falls back to the checkpointer
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
"""
from app.gateway.deps import get_thread_store
thread_store = get_thread_store(request)
checkpointer = get_checkpointer(request)
record: dict | None = await thread_store.get(thread_id)
# Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# If the thread exists in the checkpointer but not in thread_meta (e.g.
# legacy data created before thread_meta adoption), synthesize a minimal
# record from the checkpoint metadata.
if record is None and checkpoint_tuple is not None:
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
record = {
"thread_id": thread_id,
"status": "idle",
"created_at": ckpt_meta.get("created_at", ""),
"updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")),
"metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")},
}
if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {}
channel_values = checkpoint.get("channel_values", {})
return ThreadResponse(
thread_id=thread_id,
status=status,
created_at=str(record.get("created_at", "")),
updated_at=str(record.get("updated_at", "")),
metadata=record.get("metadata", {}),
values=serialize_channel_values(channel_values),
)
# ---------------------------------------------------------------------------
@router.get("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "read", owner_check=True)
async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse:
"""Get the latest state snapshot for a thread.
Channel values are serialized to ensure LangChain message objects
are converted to JSON-safe dicts.
"""
checkpointer = get_checkpointer(request)
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
try:
checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint_id = None
ckpt_config = getattr(checkpoint_tuple, "config", {})
if ckpt_config:
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
parent_checkpoint_id = None
if parent_config:
parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id")
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw]
values = serialize_channel_values(channel_values)
return ThreadStateResponse(
values=values,
next=next_tasks,
metadata=metadata,
checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))},
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
tasks=tasks,
)
@router.post("/{thread_id}/state", response_model=ThreadStateResponse)
@require_permission("threads", "write", owner_check=True, require_existing=True)
async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse:
"""Update thread state (e.g. for human-in-the-loop resume or title rename).
Writes a new checkpoint that merges *body.values* into the latest
channel values, then syncs any updated ``title`` field through the
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
change immediately in both sqlite and memory backends.
"""
from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request)
thread_store = get_thread_store(request)
# checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it
# fetches the latest checkpoint for the thread.
read_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
if body.checkpoint_id:
read_config["configurable"]["checkpoint_id"] = body.checkpoint_id
try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception:
logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# Work on mutable copies so we don't accidentally mutate cached objects.
checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {})
metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {})
channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {}))
if body.values:
channel_values.update(body.values)
checkpoint["channel_values"] = channel_values
metadata["updated_at"] = time.time()
if body.as_node:
metadata["source"] = "update"
metadata["step"] = metadata.get("step", 0) + 1
metadata["writes"] = {body.as_node: body.values}
# aput requires checkpoint_ns in the config — use the same config used for the
# read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id
# so that aput generates a fresh checkpoint ID for the new snapshot.
write_config: dict[str, Any] = {
"configurable": {
"thread_id": thread_id,
"checkpoint_ns": "",
}
}
try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception:
logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None
if isinstance(new_config, dict):
new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id")
# Sync title changes through the ThreadMetaStore abstraction so /threads/search
# reflects them immediately in both sqlite and memory backends.
if body.values and "title" in body.values:
new_title = body.values["title"]
if new_title: # Skip empty strings and None
try:
await thread_store.update_display_name(thread_id, new_title)
except Exception:
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
return ThreadStateResponse(
values=serialize_channel_values(channel_values),
next=[],
metadata=metadata,
checkpoint_id=new_checkpoint_id,
created_at=str(metadata.get("created_at", "")),
)
@router.post("/{thread_id}/history", response_model=list[HistoryEntry])
@require_permission("threads", "read", owner_check=True)
async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]:
"""Get checkpoint history for a thread.
Messages are read from the checkpointer's channel values (the
authoritative source) and serialized via
:func:`~deerflow.runtime.serialization.serialize_channel_values`.
Only the latest (first) checkpoint carries the ``messages`` key to
avoid duplicating them across every entry.
"""
checkpointer = get_checkpointer(request)
config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
if body.before:
config["configurable"]["checkpoint_id"] = body.before
entries: list[HistoryEntry] = []
is_latest_checkpoint = True
try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {})
parent_config = getattr(checkpoint_tuple, "parent_config", None)
metadata = getattr(checkpoint_tuple, "metadata", {}) or {}
checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {}
checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "")
parent_id = None
if parent_config:
parent_id = parent_config.get("configurable", {}).get("checkpoint_id")
channel_values = checkpoint.get("channel_values", {})
# Build values from checkpoint channel_values
values: dict[str, Any] = {}
if title := channel_values.get("title"):
values["title"] = title
if thread_data := channel_values.get("thread_data"):
values["thread_data"] = thread_data
# Attach messages only to the latest checkpoint entry.
if is_latest_checkpoint:
messages = channel_values.get("messages")
if messages:
values["messages"] = serialize_channel_values({"messages": messages}).get("messages", [])
is_latest_checkpoint = False
# Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")]
# Strip LangGraph internal keys from metadata
user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
# Keep step for ordering context
if "step" in metadata:
user_meta["step"] = metadata["step"]
entries.append(
HistoryEntry(
checkpoint_id=checkpoint_id,
parent_checkpoint_id=parent_id,
metadata=user_meta,
values=values,
created_at=str(metadata.get("created_at", "")),
next=next_tasks,
)
)
except Exception:
logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries

View File

@ -7,10 +7,10 @@ import stat
from fastapi import APIRouter, File, HTTPException, Request, UploadFile
from pydantic import BaseModel
from app.gateway.authz import require_permission
from deerflow.config.paths import get_paths
from deerflow.runtime.user_context import get_effective_user_id
from app.plugins.auth.security.actor_context import bind_request_actor_context
from deerflow.sandbox.sandbox_provider import get_sandbox_provider
from deerflow.config.paths import get_paths
from deerflow.runtime.actor_context import get_effective_user_id
from deerflow.uploads.manager import (
PathTraversalError,
delete_file_safe,
@ -56,7 +56,6 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
@router.post("", response_model=UploadResponse)
@require_permission("threads", "write", owner_check=True, require_existing=False)
async def upload_files(
thread_id: str,
request: Request,
@ -66,68 +65,69 @@ async def upload_files(
if not files:
raise HTTPException(status_code=400, detail="No files provided")
try:
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
for file in files:
if not file.filename:
continue
with bind_request_actor_context(request):
try:
safe_filename = normalize_filename(file.filename)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
uploads_dir = ensure_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
uploaded_files = []
try:
content = await file.read()
file_path = uploads_dir / safe_filename
file_path.write_bytes(content)
sandbox_provider = get_sandbox_provider()
sandbox_id = sandbox_provider.acquire(thread_id)
sandbox = sandbox_provider.get(sandbox_id)
virtual_path = upload_virtual_path(safe_filename)
for file in files:
if not file.filename:
continue
if sandbox_id != "local":
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
try:
safe_filename = normalize_filename(file.filename)
except ValueError:
logger.warning(f"Skipping file with unsafe filename: {file.filename!r}")
continue
file_info = {
"filename": safe_filename,
"size": str(len(content)),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
try:
content = await file.read()
file_path = uploads_dir / safe_filename
file_path.write_bytes(content)
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
virtual_path = upload_virtual_path(safe_filename)
file_ext = file_path.suffix.lower()
if file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
if sandbox_id != "local":
_make_file_sandbox_writable(file_path)
sandbox.update_file(virtual_path, content)
if sandbox_id != "local":
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
file_info = {
"filename": safe_filename,
"size": str(len(content)),
"path": str(sandbox_uploads / safe_filename),
"virtual_path": virtual_path,
"artifact_url": upload_artifact_url(thread_id, safe_filename),
}
file_info["markdown_file"] = md_path.name
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
file_info["markdown_virtual_path"] = md_virtual_path
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}")
uploaded_files.append(file_info)
file_ext = file_path.suffix.lower()
if file_ext in CONVERTIBLE_EXTENSIONS:
md_path = await convert_file_to_markdown(file_path)
if md_path:
md_virtual_path = upload_virtual_path(md_path.name)
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
if sandbox_id != "local":
_make_file_sandbox_writable(md_path)
sandbox.update_file(md_virtual_path, md_path.read_bytes())
file_info["markdown_file"] = md_path.name
file_info["markdown_path"] = str(sandbox_uploads / md_path.name)
file_info["markdown_virtual_path"] = md_virtual_path
file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name)
uploaded_files.append(file_info)
except Exception as e:
logger.error(f"Failed to upload {file.filename}: {e}")
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
return UploadResponse(
success=True,
@ -137,26 +137,25 @@ async def upload_files(
@router.get("/list", response_model=dict)
@require_permission("threads", "read", owner_check=True)
async def list_uploaded_files(thread_id: str, request: Request) -> dict:
"""List all files in a thread's uploads directory."""
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
result = list_files_in_dir(uploads_dir)
enrich_file_listing(result, thread_id)
with bind_request_actor_context(request):
try:
uploads_dir = get_uploads_dir(thread_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
result = list_files_in_dir(uploads_dir)
enrich_file_listing(result, thread_id)
# Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
# Gateway additionally includes the sandbox-relative path.
sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id())
for f in result["files"]:
f["path"] = str(sandbox_uploads / f["filename"])
return result
return result
@router.delete("/{filename}")
@require_permission("threads", "delete", owner_check=True, require_existing=True)
async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict:
"""Delete a file from a thread's uploads directory."""
try: