diff --git a/backend/app/gateway/routers/__init__.py b/backend/app/gateway/routers/__init__.py index c5f67a396..06523304e 100644 --- a/backend/app/gateway/routers/__init__.py +++ b/backend/app/gateway/routers/__init__.py @@ -1,3 +1,3 @@ -from . import artifacts, assistants_compat, mcp, models, skills, suggestions, thread_runs, threads, uploads +from . import artifacts, mcp, models, skills, suggestions, uploads -__all__ = ["artifacts", "assistants_compat", "mcp", "models", "skills", "suggestions", "threads", "thread_runs", "uploads"] +__all__ = ["artifacts", "mcp", "models", "skills", "suggestions", "uploads"] diff --git a/backend/app/gateway/routers/artifacts.py b/backend/app/gateway/routers/artifacts.py index 78ea5fa00..a58fd5c0b 100644 --- a/backend/app/gateway/routers/artifacts.py +++ b/backend/app/gateway/routers/artifacts.py @@ -7,7 +7,6 @@ from urllib.parse import quote from fastapi import APIRouter, HTTPException, Request from fastapi.responses import FileResponse, PlainTextResponse, Response -from app.gateway.authz import require_permission from app.gateway.path_utils import resolve_thread_virtual_path logger = logging.getLogger(__name__) @@ -82,7 +81,6 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte summary="Get Artifact File", description="Retrieve an artifact file generated by the AI agent. Text and binary files can be viewed inline, while active web content is always downloaded.", ) -@require_permission("threads", "read", owner_check=True) async def get_artifact(thread_id: str, path: str, request: Request, download: bool = False) -> Response: """Get an artifact file by its path. 
diff --git a/backend/app/gateway/routers/langgraph/__init__.py b/backend/app/gateway/routers/langgraph/__init__.py new file mode 100644 index 000000000..4a5e7281d --- /dev/null +++ b/backend/app/gateway/routers/langgraph/__init__.py @@ -0,0 +1,6 @@ +from .feedback import router as feedback_router +from .runs import router as runs_router +from .suggestions import router as suggestion_router +from .threads import router as threads_router + +__all__ = ["feedback_router", "runs_router", "threads_router", "suggestion_router"] diff --git a/backend/app/gateway/routers/feedback.py b/backend/app/gateway/routers/langgraph/feedback.py similarity index 50% rename from backend/app/gateway/routers/feedback.py rename to backend/app/gateway/routers/langgraph/feedback.py index ca5c1d406..a5c0ed7bc 100644 --- a/backend/app/gateway/routers/feedback.py +++ b/backend/app/gateway/routers/langgraph/feedback.py @@ -1,8 +1,4 @@ -"""Feedback endpoints — create, list, stats, delete. - -Allows users to submit thumbs-up/down feedback on runs, -optionally scoped to a specific message. 
-""" +"""LangGraph-compatible run feedback endpoints.""" from __future__ import annotations @@ -12,16 +8,12 @@ from typing import Any from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field -from app.gateway.authz import require_permission -from app.gateway.deps import get_current_user, get_feedback_repo, get_run_store +from app.gateway.dependencies import get_feedback_repository, get_run_repository +from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id +from app.plugins.auth.security.dependencies import get_current_user_id logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/threads", tags=["feedback"]) - - -# --------------------------------------------------------------------------- -# Request / response models -# --------------------------------------------------------------------------- +router = APIRouter(tags=["feedback"]) class FeedbackCreateRequest(BaseModel): @@ -30,16 +22,11 @@ class FeedbackCreateRequest(BaseModel): message_id: str | None = Field(default=None, description="Optional: scope feedback to a specific message") -class FeedbackUpsertRequest(BaseModel): - rating: int = Field(..., description="Feedback rating: +1 (positive) or -1 (negative)") - comment: str | None = Field(default=None, description="Optional text feedback") - - class FeedbackResponse(BaseModel): feedback_id: str run_id: str thread_id: str - user_id: str | None = None + owner_id: str | None = None message_id: str | None = None rating: int comment: str | None = None @@ -53,85 +40,36 @@ class FeedbackStatsResponse(BaseModel): negative: int = 0 -# --------------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------------- - - -@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) -@require_permission("threads", "write", owner_check=True, require_existing=True) 
-async def upsert_feedback( - thread_id: str, - run_id: str, - body: FeedbackUpsertRequest, - request: Request, -) -> dict[str, Any]: - """Create or update feedback for a run (idempotent).""" - if body.rating not in (1, -1): - raise HTTPException(status_code=400, detail="rating must be +1 or -1") - - user_id = await get_current_user(request) - - run_store = get_run_store(request) - run = await run_store.get(run_id) +async def _validate_run_scope(thread_id: str, run_id: str, request: Request) -> None: + run_store = get_run_repository(request) + if resolve_request_user_id(request) is None: + run = await run_store.get(run_id, user_id=None) + else: + with bind_request_actor_context(request): + run = await run_store.get(run_id) if run is None: raise HTTPException(status_code=404, detail=f"Run {run_id} not found") if run.get("thread_id") != thread_id: raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}") - feedback_repo = get_feedback_repo(request) - return await feedback_repo.upsert( - run_id=run_id, - thread_id=thread_id, - rating=body.rating, - user_id=user_id, - comment=body.comment, - ) + +async def _get_current_user(request: Request) -> str | None: + """Extract current user id from auth dependencies when available.""" + return await get_current_user_id(request) -@router.delete("/{thread_id}/runs/{run_id}/feedback") -@require_permission("threads", "delete", owner_check=True, require_existing=True) -async def delete_run_feedback( - thread_id: str, - run_id: str, - request: Request, -) -> dict[str, bool]: - """Delete the current user's feedback for a run.""" - user_id = await get_current_user(request) - feedback_repo = get_feedback_repo(request) - deleted = await feedback_repo.delete_by_run( - thread_id=thread_id, - run_id=run_id, - user_id=user_id, - ) - if not deleted: - raise HTTPException(status_code=404, detail="No feedback found for this run") - return {"success": True} - - 
-@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) -@require_permission("threads", "write", owner_check=True, require_existing=True) -async def create_feedback( +async def _create_feedback( thread_id: str, run_id: str, body: FeedbackCreateRequest, request: Request, ) -> dict[str, Any]: - """Submit feedback (thumbs-up/down) for a run.""" if body.rating not in (1, -1): raise HTTPException(status_code=400, detail="rating must be +1 or -1") - user_id = await get_current_user(request) - - # Validate run exists and belongs to thread - run_store = get_run_store(request) - run = await run_store.get(run_id) - if run is None: - raise HTTPException(status_code=404, detail=f"Run {run_id} not found") - if run.get("thread_id") != thread_id: - raise HTTPException(status_code=404, detail=f"Run {run_id} not found in thread {thread_id}") - - feedback_repo = get_feedback_repo(request) + await _validate_run_scope(thread_id, run_id, request) + user_id = await _get_current_user(request) + feedback_repo = get_feedback_repository(request) return await feedback_repo.create( run_id=run_id, thread_id=thread_id, @@ -142,41 +80,94 @@ async def create_feedback( ) +@router.put("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) +async def upsert_feedback( + thread_id: str, + run_id: str, + body: FeedbackCreateRequest, + request: Request, +) -> dict[str, Any]: + """Create or replace the run-level feedback record.""" + feedback_repo = get_feedback_repository(request) + user_id = await _get_current_user(request) + if user_id is not None: + return await feedback_repo.upsert( + run_id=run_id, + thread_id=thread_id, + rating=body.rating, + user_id=user_id, + comment=body.comment, + ) + existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, user_id=None) + for item in existing: + feedback_id = item.get("feedback_id") + if isinstance(feedback_id, str): + await feedback_repo.delete(feedback_id) + return await 
_create_feedback(thread_id, run_id, body, request) + + +@router.post("/{thread_id}/runs/{run_id}/feedback", response_model=FeedbackResponse) +async def create_feedback( + thread_id: str, + run_id: str, + body: FeedbackCreateRequest, + request: Request, +) -> dict[str, Any]: + """Submit feedback for a run.""" + return await _create_feedback(thread_id, run_id, body, request) + + @router.get("/{thread_id}/runs/{run_id}/feedback", response_model=list[FeedbackResponse]) -@require_permission("threads", "read", owner_check=True) async def list_feedback( thread_id: str, run_id: str, request: Request, ) -> list[dict[str, Any]]: """List all feedback for a run.""" - feedback_repo = get_feedback_repo(request) - return await feedback_repo.list_by_run(thread_id, run_id) + feedback_repo = get_feedback_repository(request) + user_id = await _get_current_user(request) + return await feedback_repo.list_by_run(thread_id, run_id, user_id=user_id) @router.get("/{thread_id}/runs/{run_id}/feedback/stats", response_model=FeedbackStatsResponse) -@require_permission("threads", "read", owner_check=True) async def feedback_stats( thread_id: str, run_id: str, request: Request, ) -> dict[str, Any]: - """Get aggregated feedback stats (positive/negative counts) for a run.""" - feedback_repo = get_feedback_repo(request) + """Get aggregated feedback stats for a run.""" + feedback_repo = get_feedback_repository(request) return await feedback_repo.aggregate_by_run(thread_id, run_id) +@router.delete("/{thread_id}/runs/{run_id}/feedback") +async def delete_run_feedback( + thread_id: str, + run_id: str, + request: Request, +) -> dict[str, bool]: + """Delete all feedback records for a run.""" + feedback_repo = get_feedback_repository(request) + user_id = await _get_current_user(request) + if user_id is not None: + return {"success": await feedback_repo.delete_by_run(thread_id=thread_id, run_id=run_id, user_id=user_id)} + existing = await feedback_repo.list_by_run(thread_id, run_id, limit=100, 
user_id=None) + for item in existing: + feedback_id = item.get("feedback_id") + if isinstance(feedback_id, str): + await feedback_repo.delete(feedback_id) + return {"success": True} + + @router.delete("/{thread_id}/runs/{run_id}/feedback/{feedback_id}") -@require_permission("threads", "delete", owner_check=True, require_existing=True) async def delete_feedback( thread_id: str, run_id: str, feedback_id: str, request: Request, ) -> dict[str, bool]: - """Delete a feedback record.""" - feedback_repo = get_feedback_repo(request) - # Verify feedback belongs to the specified thread/run before deleting + """Delete a single feedback record.""" + feedback_repo = get_feedback_repository(request) existing = await feedback_repo.get(feedback_id) if existing is None: raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found") diff --git a/backend/app/gateway/routers/langgraph/runs.py b/backend/app/gateway/routers/langgraph/runs.py new file mode 100644 index 000000000..6d48e6cca --- /dev/null +++ b/backend/app/gateway/routers/langgraph/runs.py @@ -0,0 +1,501 @@ +"""LangGraph-compatible runs endpoints backed by RunsFacade.""" + +from __future__ import annotations + +import json +from collections.abc import AsyncIterator +from typing import Literal + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import Response, StreamingResponse +from pydantic import BaseModel, Field + +from app.plugins.auth.security.actor_context import bind_request_actor_context +from app.gateway.services.runs.facade_factory import build_runs_facade_from_request +from app.gateway.services.runs.input import ( + AdaptedRunRequest, + RunSpecBuilder, + UnsupportedRunFeatureError, + adapt_create_run_request, + adapt_create_stream_request, + adapt_create_wait_request, + adapt_join_stream_request, + adapt_join_wait_request, +) +from deerflow.runtime.runs.types import RunRecord, RunSpec +from deerflow.runtime.stream_bridge import JSONValue, StreamEvent + +router = 
APIRouter(tags=["runs"]) + + +class RunCreateRequest(BaseModel): + assistant_id: str | None = Field(default=None, description="Agent / assistant to use") + follow_up_to_run_id: str | None = Field(default=None, description="Lineage link to the prior run") + input: dict[str, JSONValue] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})") + command: dict[str, JSONValue] | None = Field(default=None, description="LangGraph Command") + metadata: dict[str, JSONValue] | None = Field(default=None, description="Run metadata") + config: dict[str, JSONValue] | None = Field(default=None, description="RunnableConfig overrides") + context: dict[str, JSONValue] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)") + webhook: str | None = Field(default=None, description="Completion callback URL") + checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint") + checkpoint: dict[str, JSONValue] | None = Field(default=None, description="Full checkpoint object") + interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before") + interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after") + stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)") + stream_subgraphs: bool = Field(default=False, description="Include subgraph events") + stream_resumable: bool | None = Field(default=None, description="SSE resumable mode") + on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect") + on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion") + multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy") + after_seconds: float | None = Field(default=None, description="Delayed 
execution") + if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy") + feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys") + + +class RunResponse(BaseModel): + run_id: str + thread_id: str + assistant_id: str | None = None + status: str + metadata: dict[str, JSONValue] = Field(default_factory=dict) + multitask_strategy: str = "reject" + created_at: str = "" + updated_at: str = "" + + +class RunDeleteResponse(BaseModel): + deleted: bool + + +class RunMessageResponse(BaseModel): + run_id: str + content: JSONValue + metadata: dict[str, JSONValue] = Field(default_factory=dict) + created_at: str + seq: int + + +class RunMessagesResponse(BaseModel): + data: list[RunMessageResponse] + hasMore: bool = False + + +def format_sse(event: str, data: JSONValue, *, event_id: str | None = None) -> str: + """Format a single SSE frame.""" + payload = json.dumps(data, default=str, ensure_ascii=False) + parts = [f"event: {event}", f"data: {payload}"] + if event_id: + parts.append(f"id: {event_id}") + parts.append("") + parts.append("") + return "\n".join(parts) + + +def _record_to_response(record: RunRecord) -> RunResponse: + return RunResponse( + run_id=record.run_id, + thread_id=record.thread_id, + assistant_id=record.assistant_id, + status=record.status, + metadata=record.metadata, + multitask_strategy=record.multitask_strategy, + created_at=record.created_at, + updated_at=record.updated_at, + ) + + +def _trim_paginated_rows( + rows: list[dict], + *, + limit: int, + after_seq: int | None, +) -> tuple[list[dict], bool]: + has_more = len(rows) > limit + if not has_more: + return rows, False + if after_seq is not None: + return rows[:limit], True + return rows[-limit:], True + + +def _event_to_run_message(event: dict) -> RunMessageResponse: + return RunMessageResponse( + run_id=str(event["run_id"]), + content=event.get("content"), + metadata=dict(event.get("metadata") or {}), + 
created_at=str(event.get("created_at") or ""), + seq=int(event["seq"]), + ) + + +async def _sse_consumer( + stream: AsyncIterator[StreamEvent], + request: Request, + *, + cancel_on_disconnect: bool, + cancel_run, + run_id: str, +) -> AsyncIterator[str]: + try: + async for event in stream: + if await request.is_disconnected(): + break + + if event.event == "__heartbeat__": + yield ": heartbeat\n\n" + continue + + if event.event == "__end__": + yield format_sse("end", None, event_id=event.id or None) + return + + if event.event == "__cancelled__": + yield format_sse("cancel", None, event_id=event.id or None) + return + + yield format_sse(event.event, event.data, event_id=event.id or None) + finally: + if cancel_on_disconnect: + await cancel_run(run_id) + + +def _get_run_event_store(request: Request): + event_store = getattr(request.app.state, "run_event_store", None) + if event_store is None: + raise HTTPException(status_code=503, detail="Run event store not available") + return event_store + + +@router.get("/{thread_id}/runs", response_model=list[RunResponse]) +async def list_runs( + thread_id: str, + request: Request, + limit: int = 100, + offset: int = 0, + status: str | None = None, +) -> list[RunResponse]: + # Accepted for API compatibility; field projection is not implemented yet. 
+ facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + records = await facade.list_runs(thread_id) + if status is not None: + records = [record for record in records if record.status == status] + records = records[offset : offset + limit] + return [_record_to_response(record) for record in records] + + +@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse) +async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.get_run(run_id) + if record is None or record.thread_id != thread_id: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + return _record_to_response(record) + + +@router.get("/{thread_id}/runs/{run_id}/messages", response_model=RunMessagesResponse) +async def run_messages( + thread_id: str, + run_id: str, + request: Request, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, +) -> RunMessagesResponse: + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.get_run(run_id) + if record is None or record.thread_id != thread_id: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + + event_store = _get_run_event_store(request) + with bind_request_actor_context(request): + rows = await event_store.list_messages_by_run( + thread_id, + run_id, + limit=limit + 1, + before_seq=before_seq, + after_seq=after_seq, + ) + page, has_more = _trim_paginated_rows(rows, limit=limit, after_seq=after_seq) + return RunMessagesResponse(data=[_event_to_run_message(row) for row in page], hasMore=has_more) + + +def _build_spec( + *, + adapted: AdaptedRunRequest, +) -> RunSpec: + try: + return RunSpecBuilder().build(adapted) + except UnsupportedRunFeatureError as exc: + raise HTTPException(status_code=501, detail=str(exc)) from exc + + 
+@router.post("/{thread_id}/runs", response_model=RunResponse) +async def create_run( + thread_id: str, + body: RunCreateRequest, + request: Request, +) -> Response: + adapted = adapt_create_run_request( + thread_id=thread_id, + body=body.model_dump(), + headers=dict(request.headers), + query=dict(request.query_params), + ) + spec = _build_spec(adapted=adapted) + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.create_background(spec) + return Response( + content=_record_to_response(record).model_dump_json(), + media_type="application/json", + headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"}, + ) + + +@router.post("/{thread_id}/runs/stream") +async def stream_run( + thread_id: str, + body: RunCreateRequest, + request: Request, +) -> StreamingResponse: + adapted = adapt_create_stream_request( + thread_id=thread_id, + body=body.model_dump(), + headers=dict(request.headers), + query=dict(request.query_params), + ) + + spec = _build_spec(adapted=adapted) + + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record, stream = await facade.create_and_stream(spec) + + return StreamingResponse( + _sse_consumer( + stream, + request, + cancel_on_disconnect=spec.on_disconnect == "cancel", + cancel_run=facade.cancel, + run_id=record.run_id, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}", + }, + ) + + +@router.post("/{thread_id}/runs/wait") +async def wait_run( + thread_id: str, + body: RunCreateRequest, + request: Request, +) -> Response: + adapted = adapt_create_wait_request( + thread_id=thread_id, + body=body.model_dump(), + headers=dict(request.headers), + query=dict(request.query_params), + ) + spec = _build_spec(adapted=adapted) + facade = 
build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record, result = await facade.create_and_wait(spec) + return Response( + content=json.dumps(result, default=str, ensure_ascii=False), + media_type="application/json", + headers={"Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}"}, + ) + + +@router.post("/runs", response_model=RunResponse) +async def create_stateless_run(body: RunCreateRequest, request: Request) -> Response: + adapted = adapt_create_run_request( + thread_id=None, + body=body.model_dump(), + headers=dict(request.headers), + query=dict(request.query_params), + ) + spec = _build_spec(adapted=adapted) + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.create_background(spec) + return Response( + content=_record_to_response(record).model_dump_json(), + media_type="application/json", + headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"}, + ) + + +@router.post("/runs/stream") +async def create_stateless_stream_run(body: RunCreateRequest, request: Request) -> StreamingResponse: + adapted = adapt_create_stream_request( + thread_id=None, + body=body.model_dump(), + headers=dict(request.headers), + query=dict(request.query_params), + ) + spec = _build_spec(adapted=adapted) + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record, stream = await facade.create_and_stream(spec) + + return StreamingResponse( + _sse_consumer( + stream, + request, + cancel_on_disconnect=spec.on_disconnect == "cancel", + cancel_run=facade.cancel, + run_id=record.run_id, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}", + }, + ) + + +@router.post("/runs/wait") +async def wait_stateless_run(body: RunCreateRequest, 
request: Request) -> Response: + adapted = adapt_create_wait_request( + thread_id=None, + body=body.model_dump(), + headers=dict(request.headers), + query=dict(request.query_params), + ) + spec = _build_spec(adapted=adapted) + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record, result = await facade.create_and_wait(spec) + return Response( + content=json.dumps(result, default=str, ensure_ascii=False), + media_type="application/json", + headers={"Content-Location": f"/api/threads/{record.thread_id}/runs/{record.run_id}"}, + ) + + +@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None) +async def stream_existing_run( + thread_id: str, + run_id: str, + request: Request, + action: Literal["interrupt", "rollback"] | None = None, + wait: bool = False, + cancel_on_disconnect: bool = False, + stream_mode: str | None = None, +) -> StreamingResponse | Response: + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.get_run(run_id) + if record is None or record.thread_id != thread_id: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + + if action is not None: + with bind_request_actor_context(request): + cancelled = await facade.cancel(run_id, action=action) + if not cancelled: + raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable") + if wait: + with bind_request_actor_context(request): + await facade.join_wait(run_id) + return Response(status_code=204) + + adapted = adapt_join_stream_request( + thread_id=thread_id, + run_id=run_id, + headers=dict(request.headers), + query=dict(request.query_params), + ) + with bind_request_actor_context(request): + stream = await facade.join_stream(run_id, last_event_id=adapted.last_event_id) + + return StreamingResponse( + _sse_consumer( + stream, + request, + cancel_on_disconnect=cancel_on_disconnect, + cancel_run=facade.cancel, + 
run_id=run_id, + ), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + }, + ) + + +@router.get("/{thread_id}/runs/{run_id}/join") +async def join_existing_run( + thread_id: str, + run_id: str, + request: Request, + cancel_on_disconnect: bool = False, +) -> JSONValue: + # Accepted for API compatibility; current join_wait path does not change + # behavior based on client disconnect. + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.get_run(run_id) + if record is None or record.thread_id != thread_id: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + + adapted = adapt_join_wait_request( + thread_id=thread_id, + run_id=run_id, + headers=dict(request.headers), + query=dict(request.query_params), + ) + with bind_request_actor_context(request): + return await facade.join_wait(run_id, last_event_id=adapted.last_event_id) + + +@router.post("/{thread_id}/runs/{run_id}/cancel") +async def cancel_existing_run( + thread_id: str, + run_id: str, + request: Request, + wait: bool = False, + action: Literal["interrupt", "rollback"] = "interrupt", +) -> JSONValue: + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.get_run(run_id) + if record is None or record.thread_id != thread_id: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + + with bind_request_actor_context(request): + cancelled = await facade.cancel(run_id, action=action) + if not cancelled: + raise HTTPException(status_code=409, detail=f"Run {run_id} is not cancellable") + if wait: + with bind_request_actor_context(request): + return await facade.join_wait(run_id) + return {} + + +@router.delete("/{thread_id}/runs/{run_id}", response_model=RunDeleteResponse) +async def delete_run( + thread_id: str, + run_id: str, + request: Request, +) -> 
RunDeleteResponse: + facade = build_runs_facade_from_request(request) + with bind_request_actor_context(request): + record = await facade.get_run(run_id) + if record is None or record.thread_id != thread_id: + raise HTTPException(status_code=404, detail=f"Run {run_id} not found") + with bind_request_actor_context(request): + deleted = await facade.delete_run(run_id) + return RunDeleteResponse(deleted=deleted) diff --git a/backend/app/gateway/routers/langgraph/suggestions.py b/backend/app/gateway/routers/langgraph/suggestions.py new file mode 100644 index 000000000..ac54e674d --- /dev/null +++ b/backend/app/gateway/routers/langgraph/suggestions.py @@ -0,0 +1,132 @@ +import json +import logging + +from fastapi import APIRouter +from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field + +from deerflow.models import create_chat_model + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/api", tags=["suggestions"]) + + +class SuggestionMessage(BaseModel): + role: str = Field(..., description="Message role: user|assistant") + content: str = Field(..., description="Message content as plain text") + + +class SuggestionsRequest(BaseModel): + messages: list[SuggestionMessage] = Field(..., description="Recent conversation messages") + n: int = Field(default=3, ge=1, le=5, description="Number of suggestions to generate") + model_name: str | None = Field(default=None, description="Optional model override") + + +class SuggestionsResponse(BaseModel): + suggestions: list[str] = Field(default_factory=list, description="Suggested follow-up questions") + + +def _strip_markdown_code_fence(text: str) -> str: + stripped = text.strip() + if not stripped.startswith("```"): + return stripped + lines = stripped.splitlines() + if len(lines) >= 3 and lines[0].startswith("```") and lines[-1].startswith("```"): + return "\n".join(lines[1:-1]).strip() + return stripped + + +def _parse_json_string_list(text: str) -> list[str] | None: 
+ candidate = _strip_markdown_code_fence(text) + start = candidate.find("[") + end = candidate.rfind("]") + if start == -1 or end == -1 or end <= start: + return None + candidate = candidate[start : end + 1] + try: + data = json.loads(candidate) + except Exception: + return None + if not isinstance(data, list): + return None + out: list[str] = [] + for item in data: + if not isinstance(item, str): + continue + s = item.strip() + if not s: + continue + out.append(s) + return out + + +def _extract_response_text(content: object) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for block in content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict) and block.get("type") in {"text", "output_text"}: + text = block.get("text") + if isinstance(text, str): + parts.append(text) + return "\n".join(parts) if parts else "" + if content is None: + return "" + return str(content) + + +def _format_conversation(messages: list[SuggestionMessage]) -> str: + parts: list[str] = [] + for m in messages: + role = m.role.strip().lower() + if role in ("user", "human"): + parts.append(f"User: {m.content.strip()}") + elif role in ("assistant", "ai"): + parts.append(f"Assistant: {m.content.strip()}") + else: + parts.append(f"{m.role}: {m.content.strip()}") + return "\n".join(parts).strip() + + +@router.post( + "/threads/{thread_id}/suggestions", + response_model=SuggestionsResponse, + summary="Generate Follow-up Questions", + description="Generate short follow-up questions a user might ask next, based on recent conversation context.", +) +async def generate_suggestions(thread_id: str, request: SuggestionsRequest) -> SuggestionsResponse: + if not request.messages: + return SuggestionsResponse(suggestions=[]) + + n = request.n + conversation = _format_conversation(request.messages) + if not conversation: + return SuggestionsResponse(suggestions=[]) + + system_instruction = ( + "You are generating 
follow-up questions to help the user continue the conversation.\n" + f"Based on the conversation below, produce EXACTLY {n} short questions the user might ask next.\n" + "Requirements:\n" + "- Questions must be relevant to the preceding conversation.\n" + "- Questions must be written in the same language as the user.\n" + "- Keep each question concise (ideally <= 20 words / <= 40 Chinese characters).\n" + "- Do NOT include numbering, markdown, or any extra text.\n" + "- Output MUST be a JSON array of strings only.\n" + ) + user_content = f"Conversation Context:\n{conversation}\n\nGenerate {n} follow-up questions" + + try: + model = create_chat_model(name=request.model_name, thinking_enabled=False) + response = await model.ainvoke([SystemMessage(content=system_instruction), HumanMessage(content=user_content)]) + raw = _extract_response_text(response.content) + suggestions = _parse_json_string_list(raw) or [] + cleaned = [s.replace("\n", " ").strip() for s in suggestions if s.strip()] + cleaned = cleaned[:n] + return SuggestionsResponse(suggestions=cleaned) + except Exception as exc: + logger.exception("Failed to generate suggestions: thread_id=%s err=%s", thread_id, exc) + return SuggestionsResponse(suggestions=[]) diff --git a/backend/app/gateway/routers/langgraph/threads.py b/backend/app/gateway/routers/langgraph/threads.py new file mode 100644 index 000000000..e500fd0ca --- /dev/null +++ b/backend/app/gateway/routers/langgraph/threads.py @@ -0,0 +1,455 @@ +"""Thread management endpoints. + +Provides CRUD operations for threads and checkpoint state management. 
+""" + +from __future__ import annotations + +import logging +import time +import uuid +from typing import Any + +from fastapi import APIRouter, HTTPException, Request +from pydantic import BaseModel, Field + +from app.gateway.dependencies import CurrentCheckpointer, CurrentRunRepository, CurrentThreadMetaStorage +from app.infra.storage import ThreadMetaStorage +from app.plugins.auth.security.actor_context import bind_request_actor_context, resolve_request_user_id +from deerflow.config.paths import Paths, get_paths +from deerflow.runtime import serialize_channel_values + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["threads"]) + + +# --------------------------------------------------------------------------- +# Request / Response Models +# --------------------------------------------------------------------------- + + +class ThreadCreateRequest(BaseModel): + thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)") + assistant_id: str | None = Field(default=None, description="Associate thread with an assistant") + metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata") + + +class ThreadSearchRequest(BaseModel): + metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)") + limit: int = Field(default=100, ge=1, le=1000, description="Maximum results") + offset: int = Field(default=0, ge=0, description="Pagination offset") + status: str | None = Field(default=None, description="Filter by thread status") + user_id: str | None = Field(default=None, description="Filter by user ID") + assistant_id: str | None = Field(default=None, description="Filter by assistant ID") + + +class ThreadResponse(BaseModel): + thread_id: str = Field(description="Unique thread identifier") + status: str = Field(default="idle", description="Thread status") + created_at: str = Field(default="", description="ISO timestamp") + updated_at: str = 
Field(default="", description="ISO timestamp") + metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata") + values: dict[str, Any] = Field(default_factory=dict, description="Current state values") + interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts") + + +class ThreadDeleteResponse(BaseModel): + success: bool + message: str + + +class ThreadStateUpdateRequest(BaseModel): + values: dict[str, Any] | None = Field(default=None, description="Channel values to merge") + checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from") + checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object") + as_node: str | None = Field(default=None, description="Node identity for the update") + + +class ThreadStateResponse(BaseModel): + values: dict[str, Any] = Field(default_factory=dict, description="Current channel values") + next: list[str] = Field(default_factory=list, description="Next nodes to execute") + tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details") + checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info") + checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID") + parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID") + metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata") + created_at: str | None = Field(default=None, description="Checkpoint timestamp") + + +class ThreadHistoryRequest(BaseModel): + limit: int = Field(default=10, ge=1, le=100, description="Maximum entries") + before: str | None = Field(default=None, description="Cursor for pagination (checkpoint_id)") + + +class HistoryEntry(BaseModel): + checkpoint_id: str + parent_checkpoint_id: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + values: dict[str, Any] = 
Field(default_factory=dict) + created_at: str | None = None + next: list[str] = Field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def sanitize_log_param(value: str) -> str: + """Strip control characters to prevent log injection.""" + + return value.replace("\n", "").replace("\r", "").replace("\x00", "") + + +def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse: + """Delete local filesystem data for a thread.""" + path_manager = paths or get_paths() + try: + path_manager.delete_thread_dir(thread_id) + except ValueError as exc: + raise HTTPException(status_code=422, detail=str(exc)) from exc + except FileNotFoundError: + logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id)) + return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}") + except Exception as exc: + logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id)) + raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc + + logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id)) + return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}") + + +async def _thread_or_run_exists( + *, + request: Request, + thread_id: str, + thread_meta_storage: ThreadMetaStorage, + run_repo, +) -> bool: + request_user_id = resolve_request_user_id(request) + + if request_user_id is None: + thread = await thread_meta_storage.get_thread(thread_id, user_id=None) + if thread is not None: + return True + runs = await run_repo.list_by_thread(thread_id, limit=1, user_id=None) + return bool(runs) + + with bind_request_actor_context(request): + thread = await thread_meta_storage.get_thread(thread_id) + if thread is not None: + return True + runs = await 
run_repo.list_by_thread(thread_id, limit=1) + return bool(runs) + + +# --------------------------------------------------------------------------- +# Endpoints +# --------------------------------------------------------------------------- + + +@router.post("", response_model=ThreadResponse) +async def create_thread( + body: ThreadCreateRequest, + request: Request, + thread_meta_storage: CurrentThreadMetaStorage, +) -> ThreadResponse: + """Create a new thread.""" + thread_id = body.thread_id or str(uuid.uuid4()) + + request_user_id = resolve_request_user_id(request) + if request_user_id is None: + existing = await thread_meta_storage.get_thread(thread_id, user_id=None) + else: + with bind_request_actor_context(request): + existing = await thread_meta_storage.get_thread(thread_id) + if existing is not None: + return ThreadResponse( + thread_id=thread_id, + status=existing.status, + created_at=existing.created_time.isoformat() if existing.created_time else "", + updated_at=existing.updated_time.isoformat() if existing.updated_time else "", + metadata=existing.metadata, + ) + + try: + if request_user_id is None: + created = await thread_meta_storage.ensure_thread( + thread_id=thread_id, + assistant_id=body.assistant_id, + metadata=body.metadata, + user_id=None, + ) + else: + with bind_request_actor_context(request): + created = await thread_meta_storage.ensure_thread( + thread_id=thread_id, + assistant_id=body.assistant_id, + metadata=body.metadata, + ) + except Exception: + logger.exception("Failed to create thread %s", sanitize_log_param(thread_id)) + raise HTTPException(status_code=500, detail="Failed to create thread") + + logger.info("Thread created: %s", sanitize_log_param(thread_id)) + return ThreadResponse( + thread_id=thread_id, + status=created.status, + created_at=created.created_time.isoformat() if created.created_time else "", + updated_at=created.updated_time.isoformat() if created.updated_time else "", + metadata=created.metadata, + ) + + 
+@router.post("/search", response_model=list[ThreadResponse]) +async def search_threads( + body: ThreadSearchRequest, + request: Request, + thread_meta_storage: CurrentThreadMetaStorage, +) -> list[ThreadResponse]: + """Search threads with filters.""" + try: + request_user_id = resolve_request_user_id(request) + if request_user_id is None: + threads = await thread_meta_storage.search_threads( + metadata=body.metadata or None, + status=body.status, + user_id=body.user_id, + assistant_id=body.assistant_id, + limit=body.limit, + offset=body.offset, + ) + else: + with bind_request_actor_context(request): + threads = await thread_meta_storage.search_threads( + metadata=body.metadata or None, + status=body.status, + assistant_id=body.assistant_id, + limit=body.limit, + offset=body.offset, + ) + except Exception: + logger.exception("Failed to search threads") + raise HTTPException(status_code=500, detail="Failed to search threads") + + return [ + ThreadResponse( + thread_id=t.thread_id, + status=t.status, + created_at=t.created_time.isoformat() if t.created_time else "", + updated_at=t.updated_time.isoformat() if t.updated_time else "", + metadata=t.metadata, + values={"title": t.display_name} if t.display_name else {}, + interrupts={}, + ) + for t in threads + ] + + +@router.delete("/{thread_id}", response_model=ThreadDeleteResponse) +async def delete_thread( + thread_id: str, + checkpointer: CurrentCheckpointer, + thread_meta_storage: CurrentThreadMetaStorage, +) -> ThreadDeleteResponse: + """Delete a thread and all associated data.""" + response = _delete_thread_data(thread_id) + + # Remove checkpoints (best-effort) + try: + if hasattr(checkpointer, "adelete_thread"): + await checkpointer.adelete_thread(thread_id) + except Exception: + logger.debug("Could not delete checkpoints for thread %s", sanitize_log_param(thread_id)) + + # Remove thread_meta (best-effort) + try: + await thread_meta_storage.delete_thread(thread_id) + except Exception: + logger.debug("Could not 
delete thread_meta for %s", sanitize_log_param(thread_id)) + + return response + + +@router.get("/{thread_id}/state", response_model=ThreadStateResponse) +async def get_thread_state( + thread_id: str, + request: Request, + checkpointer: CurrentCheckpointer, + thread_meta_storage: CurrentThreadMetaStorage, + run_repo: CurrentRunRepository, +) -> ThreadStateResponse: + """Get the latest state snapshot for a thread.""" + config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + + try: + checkpoint_tuple = await checkpointer.aget_tuple(config) + except Exception: + logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id)) + raise HTTPException(status_code=500, detail="Failed to get thread state") + + if checkpoint_tuple is None: + if await _thread_or_run_exists( + request=request, + thread_id=thread_id, + thread_meta_storage=thread_meta_storage, + run_repo=run_repo, + ): + return ThreadStateResponse() + raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") + + checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} + metadata = getattr(checkpoint_tuple, "metadata", {}) or {} + channel_values = checkpoint.get("channel_values", {}) + + ckpt_config = getattr(checkpoint_tuple, "config", {}) or {} + checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id") + + parent_config = getattr(checkpoint_tuple, "parent_config", None) + parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None + + tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] + next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")] + tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw] + + return ThreadStateResponse( + values=serialize_channel_values(channel_values), + next=next_nodes, + tasks=tasks, + checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))}, + checkpoint_id=checkpoint_id, + 
parent_checkpoint_id=parent_checkpoint_id, + metadata=metadata, + created_at=str(metadata.get("created_at", "")), + ) + + +@router.post("/{thread_id}/state", response_model=ThreadStateResponse) +async def update_thread_state( + thread_id: str, + body: ThreadStateUpdateRequest, + checkpointer: CurrentCheckpointer, + thread_meta_storage: CurrentThreadMetaStorage, +) -> ThreadStateResponse: + """Update thread state (human-in-the-loop or title rename).""" + read_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + if body.checkpoint_id: + read_config["configurable"]["checkpoint_id"] = body.checkpoint_id + + try: + checkpoint_tuple = await checkpointer.aget_tuple(read_config) + except Exception: + logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id)) + raise HTTPException(status_code=500, detail="Failed to get thread state") + + if checkpoint_tuple is None: + raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") + + checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {}) + metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {}) + channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {})) + + if body.values: + channel_values.update(body.values) + + checkpoint["channel_values"] = channel_values + metadata["updated_at"] = time.time() + + if body.as_node: + metadata["source"] = "update" + metadata["step"] = metadata.get("step", 0) + 1 + metadata["writes"] = {body.as_node: body.values} + + write_config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + try: + new_config = await checkpointer.aput(write_config, checkpoint, metadata, {}) + except Exception: + logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id)) + raise HTTPException(status_code=500, detail="Failed to update thread state") + + new_checkpoint_id: str | None = None + if 
isinstance(new_config, dict): + new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id") + + # Sync title to thread_meta + if body.values and "title" in body.values: + new_title = body.values["title"] + if new_title: + try: + await thread_meta_storage.sync_thread_title( + thread_id=thread_id, + title=new_title, + ) + except Exception: + logger.debug("Failed to sync title for %s", sanitize_log_param(thread_id)) + + return ThreadStateResponse( + values=serialize_channel_values(channel_values), + next=[], + metadata=metadata, + checkpoint_id=new_checkpoint_id, + created_at=str(metadata.get("created_at", "")), + ) + + +@router.post("/{thread_id}/history", response_model=list[HistoryEntry]) +async def get_thread_history( + thread_id: str, + body: ThreadHistoryRequest, + request: Request, + checkpointer: CurrentCheckpointer, + thread_meta_storage: CurrentThreadMetaStorage, + run_repo: CurrentRunRepository, +) -> list[HistoryEntry]: + """Get checkpoint history for a thread.""" + config: dict[str, Any] = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + if body.before: + config["configurable"]["checkpoint_id"] = body.before + + entries: list[HistoryEntry] = [] + is_first = True + + try: + async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit): + ckpt_config = getattr(checkpoint_tuple, "config", {}) or {} + parent_config = getattr(checkpoint_tuple, "parent_config", None) + metadata = getattr(checkpoint_tuple, "metadata", {}) or {} + checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} + + checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "") + parent_id = parent_config.get("configurable", {}).get("checkpoint_id") if parent_config else None + channel_values = checkpoint.get("channel_values", {}) + + values: dict[str, Any] = {} + if title := channel_values.get("title"): + values["title"] = title + if is_first and (messages := channel_values.get("messages")): + values["messages"] = 
serialize_channel_values({"messages": messages}).get("messages", []) + is_first = False + + tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] + next_nodes = [t.name for t in tasks_raw if hasattr(t, "name")] + + entries.append( + HistoryEntry( + checkpoint_id=checkpoint_id, + parent_checkpoint_id=parent_id, + metadata=metadata, + values=values, + created_at=str(metadata.get("created_at", "")), + next=next_nodes, + ) + ) + except Exception: + logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id)) + raise HTTPException(status_code=500, detail="Failed to get thread history") + + if not entries and await _thread_or_run_exists( + request=request, + thread_id=thread_id, + thread_meta_storage=thread_meta_storage, + run_repo=run_repo, + ): + return [] + + return entries diff --git a/backend/app/gateway/routers/memory.py b/backend/app/gateway/routers/memory.py index ca9e5f5e5..191a20828 100644 --- a/backend/app/gateway/routers/memory.py +++ b/backend/app/gateway/routers/memory.py @@ -1,8 +1,9 @@ """Memory API router for retrieving and managing global memory data.""" -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field +from app.plugins.auth.security.actor_context import bind_request_actor_context from deerflow.agents.memory.updater import ( clear_memory_data, create_memory_fact, @@ -13,7 +14,7 @@ from deerflow.agents.memory.updater import ( update_memory_fact, ) from deerflow.config.memory_config import get_memory_config -from deerflow.runtime.user_context import get_effective_user_id +from deerflow.runtime.actor_context import get_effective_user_id router = APIRouter(prefix="/api", tags=["memory"]) @@ -114,7 +115,7 @@ class MemoryStatusResponse(BaseModel): summary="Get Memory Data", description="Retrieve the current global memory data including user context, history, and facts.", ) -async def get_memory() -> MemoryResponse: +async def 
get_memory(request: Request) -> MemoryResponse: """Get the current global memory data. Returns: @@ -148,8 +149,9 @@ async def get_memory() -> MemoryResponse: } ``` """ - memory_data = get_memory_data(user_id=get_effective_user_id()) - return MemoryResponse(**memory_data) + with bind_request_actor_context(request): + memory_data = get_memory_data(user_id=get_effective_user_id()) + return MemoryResponse(**memory_data) @router.post( @@ -159,7 +161,7 @@ async def get_memory() -> MemoryResponse: summary="Reload Memory Data", description="Reload memory data from the storage file, refreshing the in-memory cache.", ) -async def reload_memory() -> MemoryResponse: +async def reload_memory(request: Request) -> MemoryResponse: """Reload memory data from file. This forces a reload of the memory data from the storage file, @@ -168,8 +170,9 @@ async def reload_memory() -> MemoryResponse: Returns: The reloaded memory data. """ - memory_data = reload_memory_data(user_id=get_effective_user_id()) - return MemoryResponse(**memory_data) + with bind_request_actor_context(request): + memory_data = reload_memory_data(user_id=get_effective_user_id()) + return MemoryResponse(**memory_data) @router.delete( @@ -179,14 +182,15 @@ async def reload_memory() -> MemoryResponse: summary="Clear All Memory Data", description="Delete all saved memory data and reset the memory structure to an empty state.", ) -async def clear_memory() -> MemoryResponse: +async def clear_memory(request: Request) -> MemoryResponse: """Clear all persisted memory data.""" - try: - memory_data = clear_memory_data(user_id=get_effective_user_id()) - except OSError as exc: - raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc + with bind_request_actor_context(request): + try: + memory_data = clear_memory_data(user_id=get_effective_user_id()) + except OSError as exc: + raise HTTPException(status_code=500, detail="Failed to clear memory data.") from exc - return MemoryResponse(**memory_data) + 
return MemoryResponse(**memory_data) @router.post( @@ -196,21 +200,22 @@ async def clear_memory() -> MemoryResponse: summary="Create Memory Fact", description="Create a single saved memory fact manually.", ) -async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryResponse: +async def create_memory_fact_endpoint(request: Request, payload: FactCreateRequest) -> MemoryResponse: """Create a single fact manually.""" - try: - memory_data = create_memory_fact( - content=request.content, - category=request.category, - confidence=request.confidence, - user_id=get_effective_user_id(), - ) - except ValueError as exc: - raise _map_memory_fact_value_error(exc) from exc - except OSError as exc: - raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc + with bind_request_actor_context(request): + try: + memory_data = create_memory_fact( + content=payload.content, + category=payload.category, + confidence=payload.confidence, + user_id=get_effective_user_id(), + ) + except ValueError as exc: + raise _map_memory_fact_value_error(exc) from exc + except OSError as exc: + raise HTTPException(status_code=500, detail="Failed to create memory fact.") from exc - return MemoryResponse(**memory_data) + return MemoryResponse(**memory_data) @router.delete( @@ -220,16 +225,17 @@ async def create_memory_fact_endpoint(request: FactCreateRequest) -> MemoryRespo summary="Delete Memory Fact", description="Delete a single saved memory fact by its fact id.", ) -async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse: +async def delete_memory_fact_endpoint(fact_id: str, request: Request) -> MemoryResponse: """Delete a single fact from memory by fact id.""" - try: - memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id()) - except KeyError as exc: - raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc - except OSError as exc: - raise HTTPException(status_code=500, detail="Failed to 
delete memory fact.") from exc + with bind_request_actor_context(request): + try: + memory_data = delete_memory_fact(fact_id, user_id=get_effective_user_id()) + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc + except OSError as exc: + raise HTTPException(status_code=500, detail="Failed to delete memory fact.") from exc - return MemoryResponse(**memory_data) + return MemoryResponse(**memory_data) @router.patch( @@ -239,24 +245,25 @@ async def delete_memory_fact_endpoint(fact_id: str) -> MemoryResponse: summary="Patch Memory Fact", description="Partially update a single saved memory fact by its fact id while preserving omitted fields.", ) -async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) -> MemoryResponse: +async def update_memory_fact_endpoint(fact_id: str, request: Request, payload: FactPatchRequest) -> MemoryResponse: """Partially update a single fact manually.""" - try: - memory_data = update_memory_fact( - fact_id=fact_id, - content=request.content, - category=request.category, - confidence=request.confidence, - user_id=get_effective_user_id(), - ) - except ValueError as exc: - raise _map_memory_fact_value_error(exc) from exc - except KeyError as exc: - raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc - except OSError as exc: - raise HTTPException(status_code=500, detail="Failed to update memory fact.") from exc + with bind_request_actor_context(request): + try: + memory_data = update_memory_fact( + fact_id=fact_id, + content=payload.content, + category=payload.category, + confidence=payload.confidence, + user_id=get_effective_user_id(), + ) + except ValueError as exc: + raise _map_memory_fact_value_error(exc) from exc + except KeyError as exc: + raise HTTPException(status_code=404, detail=f"Memory fact '{fact_id}' not found.") from exc + except OSError as exc: + raise HTTPException(status_code=500, detail="Failed to 
update memory fact.") from exc - return MemoryResponse(**memory_data) + return MemoryResponse(**memory_data) @router.get( @@ -266,10 +273,11 @@ async def update_memory_fact_endpoint(fact_id: str, request: FactPatchRequest) - summary="Export Memory Data", description="Export the current global memory data as JSON for backup or transfer.", ) -async def export_memory() -> MemoryResponse: +async def export_memory(request: Request) -> MemoryResponse: """Export the current memory data.""" - memory_data = get_memory_data(user_id=get_effective_user_id()) - return MemoryResponse(**memory_data) + with bind_request_actor_context(request): + memory_data = get_memory_data(user_id=get_effective_user_id()) + return MemoryResponse(**memory_data) @router.post( @@ -279,14 +287,15 @@ async def export_memory() -> MemoryResponse: summary="Import Memory Data", description="Import and overwrite the current global memory data from a JSON payload.", ) -async def import_memory(request: MemoryResponse) -> MemoryResponse: +async def import_memory(request: Request, payload: MemoryResponse) -> MemoryResponse: """Import and persist memory data.""" - try: - memory_data = import_memory_data(request.model_dump(), user_id=get_effective_user_id()) - except OSError as exc: - raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc + with bind_request_actor_context(request): + try: + memory_data = import_memory_data(payload.model_dump(), user_id=get_effective_user_id()) + except OSError as exc: + raise HTTPException(status_code=500, detail="Failed to import memory data.") from exc - return MemoryResponse(**memory_data) + return MemoryResponse(**memory_data) @router.get( @@ -333,24 +342,25 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse: summary="Get Memory Status", description="Retrieve both memory configuration and current data in a single request.", ) -async def get_memory_status() -> MemoryStatusResponse: +async def get_memory_status(request: Request) 
-> MemoryStatusResponse: """Get the memory system status including configuration and data. Returns: Combined memory configuration and current data. """ - config = get_memory_config() - memory_data = get_memory_data(user_id=get_effective_user_id()) + with bind_request_actor_context(request): + config = get_memory_config() + memory_data = get_memory_data(user_id=get_effective_user_id()) - return MemoryStatusResponse( - config=MemoryConfigResponse( - enabled=config.enabled, - storage_path=config.storage_path, - debounce_seconds=config.debounce_seconds, - max_facts=config.max_facts, - fact_confidence_threshold=config.fact_confidence_threshold, - injection_enabled=config.injection_enabled, - max_injection_tokens=config.max_injection_tokens, - ), - data=MemoryResponse(**memory_data), - ) + return MemoryStatusResponse( + config=MemoryConfigResponse( + enabled=config.enabled, + storage_path=config.storage_path, + debounce_seconds=config.debounce_seconds, + max_facts=config.max_facts, + fact_confidence_threshold=config.fact_confidence_threshold, + injection_enabled=config.injection_enabled, + max_injection_tokens=config.max_injection_tokens, + ), + data=MemoryResponse(**memory_data), + ) diff --git a/backend/app/gateway/routers/runs.py b/backend/app/gateway/routers/runs.py deleted file mode 100644 index f2775466c..000000000 --- a/backend/app/gateway/routers/runs.py +++ /dev/null @@ -1,143 +0,0 @@ -"""Stateless runs endpoints -- stream and wait without a pre-existing thread. - -These endpoints auto-create a temporary thread when no ``thread_id`` is -supplied in the request body. When a ``thread_id`` **is** provided, it -is reused so that conversation history is preserved across calls. 
-""" - -from __future__ import annotations - -import asyncio -import logging -import uuid - -from fastapi import APIRouter, HTTPException, Query, Request -from fastapi.responses import StreamingResponse - -from app.gateway.authz import require_permission -from app.gateway.deps import get_checkpointer, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge -from app.gateway.routers.thread_runs import RunCreateRequest -from app.gateway.services import sse_consumer, start_run -from deerflow.runtime import serialize_channel_values - -logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/runs", tags=["runs"]) - - -def _resolve_thread_id(body: RunCreateRequest) -> str: - """Return the thread_id from the request body, or generate a new one.""" - thread_id = (body.config or {}).get("configurable", {}).get("thread_id") - if thread_id: - return str(thread_id) - return str(uuid.uuid4()) - - -@router.post("/stream") -async def stateless_stream(body: RunCreateRequest, request: Request) -> StreamingResponse: - """Create a run and stream events via SSE. - - If ``config.configurable.thread_id`` is provided, the run is created - on the given thread so that conversation history is preserved. - Otherwise a new temporary thread is created. - """ - thread_id = _resolve_thread_id(body) - bridge = get_stream_bridge(request) - run_mgr = get_run_manager(request) - record = await start_run(body, thread_id, request) - - return StreamingResponse( - sse_consumer(bridge, record, request, run_mgr), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - "Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}", - }, - ) - - -@router.post("/wait", response_model=dict) -async def stateless_wait(body: RunCreateRequest, request: Request) -> dict: - """Create a run and block until completion. 
- - If ``config.configurable.thread_id`` is provided, the run is created - on the given thread so that conversation history is preserved. - Otherwise a new temporary thread is created. - """ - thread_id = _resolve_thread_id(body) - record = await start_run(body, thread_id, request) - - if record.task is not None: - try: - await record.task - except asyncio.CancelledError: - pass - - checkpointer = get_checkpointer(request) - config = {"configurable": {"thread_id": thread_id}} - try: - checkpoint_tuple = await checkpointer.aget_tuple(config) - if checkpoint_tuple is not None: - checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} - channel_values = checkpoint.get("channel_values", {}) - return serialize_channel_values(channel_values) - except Exception: - logger.exception("Failed to fetch final state for run %s", record.run_id) - - return {"status": record.status.value, "error": record.error} - - -# --------------------------------------------------------------------------- -# Run-scoped read endpoints -# --------------------------------------------------------------------------- - - -async def _resolve_run(run_id: str, request: Request) -> dict: - """Fetch run by run_id with user ownership check. Raises 404 if not found.""" - run_store = get_run_store(request) - record = await run_store.get(run_id) # user_id=AUTO filters by contextvar - if record is None: - raise HTTPException(status_code=404, detail=f"Run {run_id} not found") - return record - - -@router.get("/{run_id}/messages") -@require_permission("runs", "read") -async def run_messages( - run_id: str, - request: Request, - limit: int = Query(default=50, le=200, ge=1), - before_seq: int | None = Query(default=None), - after_seq: int | None = Query(default=None), -) -> dict: - """Return paginated messages for a run (cursor-based). 
- - Pagination: - - after_seq: messages with seq > after_seq (forward) - - before_seq: messages with seq < before_seq (backward) - - neither: latest messages - - Response: { data: [...], has_more: bool } - """ - run = await _resolve_run(run_id, request) - event_store = get_run_event_store(request) - rows = await event_store.list_messages_by_run( - run["thread_id"], - run_id, - limit=limit + 1, - before_seq=before_seq, - after_seq=after_seq, - ) - has_more = len(rows) > limit - data = rows[:limit] if has_more else rows - return {"data": data, "has_more": has_more} - - -@router.get("/{run_id}/feedback") -@require_permission("runs", "read") -async def run_feedback(run_id: str, request: Request) -> list[dict]: - """Return all feedback for a run.""" - run = await _resolve_run(run_id, request) - feedback_repo = get_feedback_repo(request) - return await feedback_repo.list_by_run(run["thread_id"], run_id) diff --git a/backend/app/gateway/routers/suggestions.py b/backend/app/gateway/routers/suggestions.py index 0da5e4322..3624a4b41 100644 --- a/backend/app/gateway/routers/suggestions.py +++ b/backend/app/gateway/routers/suggestions.py @@ -5,7 +5,6 @@ from fastapi import APIRouter, Request from langchain_core.messages import HumanMessage, SystemMessage from pydantic import BaseModel, Field -from app.gateway.authz import require_permission from deerflow.models import create_chat_model logger = logging.getLogger(__name__) @@ -99,7 +98,6 @@ def _format_conversation(messages: list[SuggestionMessage]) -> str: summary="Generate Follow-up Questions", description="Generate short follow-up questions a user might ask next, based on recent conversation context.", ) -@require_permission("threads", "read", owner_check=True) async def generate_suggestions(thread_id: str, body: SuggestionsRequest, request: Request) -> SuggestionsResponse: if not body.messages: return SuggestionsResponse(suggestions=[]) diff --git a/backend/app/gateway/routers/thread_runs.py 
b/backend/app/gateway/routers/thread_runs.py deleted file mode 100644 index e6847c50f..000000000 --- a/backend/app/gateway/routers/thread_runs.py +++ /dev/null @@ -1,377 +0,0 @@ -"""Runs endpoints — create, stream, wait, cancel. - -Implements the LangGraph Platform runs API on top of -:class:`deerflow.agents.runs.RunManager` and -:class:`deerflow.agents.stream_bridge.StreamBridge`. - -SSE format is aligned with the LangGraph Platform protocol so that -the ``useStream`` React hook from ``@langchain/langgraph-sdk/react`` -works without modification. -""" - -from __future__ import annotations - -import asyncio -import logging -from typing import Any, Literal - -from fastapi import APIRouter, HTTPException, Query, Request -from fastapi.responses import Response, StreamingResponse -from pydantic import BaseModel, Field - -from app.gateway.authz import require_permission -from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge -from app.gateway.services import sse_consumer, start_run -from deerflow.runtime import RunRecord, serialize_channel_values - -logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/threads", tags=["runs"]) - - -# --------------------------------------------------------------------------- -# Request / response models -# --------------------------------------------------------------------------- - - -class RunCreateRequest(BaseModel): - assistant_id: str | None = Field(default=None, description="Agent / assistant to use") - input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. 
{messages: [...]})") - command: dict[str, Any] | None = Field(default=None, description="LangGraph Command") - metadata: dict[str, Any] | None = Field(default=None, description="Run metadata") - config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides") - context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)") - webhook: str | None = Field(default=None, description="Completion callback URL") - checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint") - checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object") - interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before") - interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after") - stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)") - stream_subgraphs: bool = Field(default=False, description="Include subgraph events") - stream_resumable: bool | None = Field(default=None, description="SSE resumable mode") - on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect") - on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion") - multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy") - after_seconds: float | None = Field(default=None, description="Delayed execution") - if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy") - feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys") - - -class RunResponse(BaseModel): - run_id: str - thread_id: str - assistant_id: str | None = None - status: str - metadata: dict[str, Any] = 
Field(default_factory=dict) - kwargs: dict[str, Any] = Field(default_factory=dict) - multitask_strategy: str = "reject" - created_at: str = "" - updated_at: str = "" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _record_to_response(record: RunRecord) -> RunResponse: - return RunResponse( - run_id=record.run_id, - thread_id=record.thread_id, - assistant_id=record.assistant_id, - status=record.status.value, - metadata=record.metadata, - kwargs=record.kwargs, - multitask_strategy=record.multitask_strategy, - created_at=record.created_at, - updated_at=record.updated_at, - ) - - -# --------------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------------- - - -@router.post("/{thread_id}/runs", response_model=RunResponse) -@require_permission("runs", "create", owner_check=True, require_existing=True) -async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse: - """Create a background run (returns immediately).""" - record = await start_run(body, thread_id, request) - return _record_to_response(record) - - -@router.post("/{thread_id}/runs/stream") -@require_permission("runs", "create", owner_check=True, require_existing=True) -async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse: - """Create a run and stream events via SSE. - - The response includes a ``Content-Location`` header with the run's - resource URL, matching the LangGraph Platform protocol. The - ``useStream`` React hook uses this to extract run metadata. 
- """ - bridge = get_stream_bridge(request) - run_mgr = get_run_manager(request) - record = await start_run(body, thread_id, request) - - return StreamingResponse( - sse_consumer(bridge, record, request, run_mgr), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - # LangGraph Platform includes run metadata in this header. - # The SDK uses a greedy regex to extract the run id from this path, - # so it must point at the canonical run resource without extra suffixes. - "Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}", - }, - ) - - -@router.post("/{thread_id}/runs/wait", response_model=dict) -@require_permission("runs", "create", owner_check=True, require_existing=True) -async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict: - """Create a run and block until it completes, returning the final state.""" - record = await start_run(body, thread_id, request) - - if record.task is not None: - try: - await record.task - except asyncio.CancelledError: - pass - - checkpointer = get_checkpointer(request) - config = {"configurable": {"thread_id": thread_id}} - try: - checkpoint_tuple = await checkpointer.aget_tuple(config) - if checkpoint_tuple is not None: - checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} - channel_values = checkpoint.get("channel_values", {}) - return serialize_channel_values(channel_values) - except Exception: - logger.exception("Failed to fetch final state for run %s", record.run_id) - - return {"status": record.status.value, "error": record.error} - - -@router.get("/{thread_id}/runs", response_model=list[RunResponse]) -@require_permission("runs", "read", owner_check=True) -async def list_runs(thread_id: str, request: Request) -> list[RunResponse]: - """List all runs for a thread.""" - run_mgr = get_run_manager(request) - records = await run_mgr.list_by_thread(thread_id) - return [_record_to_response(r) for r 
in records] - - -@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse) -@require_permission("runs", "read", owner_check=True) -async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse: - """Get details of a specific run.""" - run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) - if record is None or record.thread_id != thread_id: - raise HTTPException(status_code=404, detail=f"Run {run_id} not found") - return _record_to_response(record) - - -@router.post("/{thread_id}/runs/{run_id}/cancel") -@require_permission("runs", "cancel", owner_check=True, require_existing=True) -async def cancel_run( - thread_id: str, - run_id: str, - request: Request, - wait: bool = Query(default=False, description="Block until run completes after cancel"), - action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"), -) -> Response: - """Cancel a running or pending run. - - - action=interrupt: Stop execution, keep current checkpoint (can be resumed) - - action=rollback: Stop execution, revert to pre-run checkpoint state - - wait=true: Block until the run fully stops, return 204 - - wait=false: Return immediately with 202 - """ - run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) - if record is None or record.thread_id != thread_id: - raise HTTPException(status_code=404, detail=f"Run {run_id} not found") - - cancelled = await run_mgr.cancel(run_id, action=action) - if not cancelled: - raise HTTPException( - status_code=409, - detail=f"Run {run_id} is not cancellable (status: {record.status.value})", - ) - - if wait and record.task is not None: - try: - await record.task - except asyncio.CancelledError: - pass - return Response(status_code=204) - - return Response(status_code=202) - - -@router.get("/{thread_id}/runs/{run_id}/join") -@require_permission("runs", "read", owner_check=True) -async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse: - """Join 
an existing run's SSE stream.""" - bridge = get_stream_bridge(request) - run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) - if record is None or record.thread_id != thread_id: - raise HTTPException(status_code=404, detail=f"Run {run_id} not found") - - return StreamingResponse( - sse_consumer(bridge, record, request, run_mgr), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None) -@require_permission("runs", "read", owner_check=True) -async def stream_existing_run( - thread_id: str, - run_id: str, - request: Request, - action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"), - wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"), -): - """Join an existing run's SSE stream (GET), or cancel-then-stream (POST). - - The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use - ``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback`` - is present the run is cancelled first; the response then streams any - remaining buffered events so the client observes a clean shutdown. 
- """ - run_mgr = get_run_manager(request) - record = run_mgr.get(run_id) - if record is None or record.thread_id != thread_id: - raise HTTPException(status_code=404, detail=f"Run {run_id} not found") - - # Cancel if an action was requested (stop-button / interrupt flow) - if action is not None: - cancelled = await run_mgr.cancel(run_id, action=action) - if cancelled and wait and record.task is not None: - try: - await record.task - except (asyncio.CancelledError, Exception): - pass - return Response(status_code=204) - - bridge = get_stream_bridge(request) - return StreamingResponse( - sse_consumer(bridge, record, request, run_mgr), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "X-Accel-Buffering": "no", - }, - ) - - -# --------------------------------------------------------------------------- -# Messages / Events / Token usage endpoints -# --------------------------------------------------------------------------- - - -@router.get("/{thread_id}/messages") -@require_permission("runs", "read", owner_check=True) -async def list_thread_messages( - thread_id: str, - request: Request, - limit: int = Query(default=50, le=200), - before_seq: int | None = Query(default=None), - after_seq: int | None = Query(default=None), -) -> list[dict]: - """Return displayable messages for a thread (across all runs), with feedback attached.""" - event_store = get_run_event_store(request) - messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq) - - # Attach feedback to the last AI message of each run - feedback_repo = get_feedback_repo(request) - user_id = await get_current_user(request) - feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id) - - # Find the last ai_message per run_id - last_ai_per_run: dict[str, int] = {} # run_id -> index in messages list - for i, msg in enumerate(messages): - if msg.get("event_type") == "ai_message": - 
last_ai_per_run[msg["run_id"]] = i - - # Attach feedback field - last_ai_indices = set(last_ai_per_run.values()) - for i, msg in enumerate(messages): - if i in last_ai_indices: - run_id = msg["run_id"] - fb = feedback_map.get(run_id) - msg["feedback"] = ( - { - "feedback_id": fb["feedback_id"], - "rating": fb["rating"], - "comment": fb.get("comment"), - } - if fb - else None - ) - else: - msg["feedback"] = None - - return messages - - -@router.get("/{thread_id}/runs/{run_id}/messages") -@require_permission("runs", "read", owner_check=True) -async def list_run_messages( - thread_id: str, - run_id: str, - request: Request, - limit: int = Query(default=50, le=200, ge=1), - before_seq: int | None = Query(default=None), - after_seq: int | None = Query(default=None), -) -> dict: - """Return paginated messages for a specific run. - - Response: { data: [...], has_more: bool } - """ - event_store = get_run_event_store(request) - rows = await event_store.list_messages_by_run( - thread_id, - run_id, - limit=limit + 1, - before_seq=before_seq, - after_seq=after_seq, - ) - has_more = len(rows) > limit - data = rows[:limit] if has_more else rows - return {"data": data, "has_more": has_more} - - -@router.get("/{thread_id}/runs/{run_id}/events") -@require_permission("runs", "read", owner_check=True) -async def list_run_events( - thread_id: str, - run_id: str, - request: Request, - event_types: str | None = Query(default=None), - limit: int = Query(default=500, le=2000), -) -> list[dict]: - """Return the full event stream for a run (debug/audit).""" - event_store = get_run_event_store(request) - types = event_types.split(",") if event_types else None - return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit) - - -@router.get("/{thread_id}/token-usage") -@require_permission("threads", "read", owner_check=True) -async def thread_token_usage(thread_id: str, request: Request) -> dict: - """Thread-level token usage aggregation.""" - run_store = 
get_run_store(request) - agg = await run_store.aggregate_tokens_by_thread(thread_id) - return {"thread_id": thread_id, **agg} diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py deleted file mode 100644 index c7bfa69b6..000000000 --- a/backend/app/gateway/routers/threads.py +++ /dev/null @@ -1,621 +0,0 @@ -"""Thread CRUD, state, and history endpoints. - -Combines the existing thread-local filesystem cleanup with LangGraph -Platform-compatible thread management backed by the checkpointer. - -Channel values returned in state responses are serialized through -:func:`deerflow.runtime.serialization.serialize_channel_values` to -ensure LangChain message objects are converted to JSON-safe dicts -matching the LangGraph Platform wire format expected by the -``useStream`` React hook. -""" - -from __future__ import annotations - -import logging -import re -import time -import uuid -from typing import Any - -from fastapi import APIRouter, HTTPException, Request -from pydantic import BaseModel, Field, field_validator - -from app.gateway.authz import require_permission -from app.gateway.deps import get_checkpointer -from app.gateway.utils import sanitize_log_param -from deerflow.config.paths import Paths, get_paths -from deerflow.runtime import serialize_channel_values -from deerflow.runtime.user_context import get_effective_user_id - -logger = logging.getLogger(__name__) -router = APIRouter(prefix="/api/threads", tags=["threads"]) - - -# Metadata keys that the server controls; clients are not allowed to set -# them. Pydantic ``@field_validator("metadata")`` strips them on every -# inbound model below so a malicious client cannot reflect a forged -# owner identity through the API surface. Defense-in-depth — the -# row-level invariant is still ``threads_meta.user_id`` populated from -# the auth contextvar; this list closes the metadata-blob echo gap. 
-_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"}) - - -def _strip_reserved_metadata(metadata: dict[str, Any] | None) -> dict[str, Any]: - """Return ``metadata`` with server-controlled keys removed.""" - if not metadata: - return metadata or {} - return {k: v for k, v in metadata.items() if k not in _SERVER_RESERVED_METADATA_KEYS} - - -# --------------------------------------------------------------------------- -# Response / request models -# --------------------------------------------------------------------------- - - -class ThreadDeleteResponse(BaseModel): - """Response model for thread cleanup.""" - - success: bool - message: str - - -class ThreadResponse(BaseModel): - """Response model for a single thread.""" - - thread_id: str = Field(description="Unique thread identifier") - status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error") - created_at: str = Field(default="", description="ISO timestamp") - updated_at: str = Field(default="", description="ISO timestamp") - metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata") - values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values") - interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts") - - -class ThreadCreateRequest(BaseModel): - """Request body for creating a thread.""" - - thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)") - assistant_id: str | None = Field(default=None, description="Associate thread with an assistant") - metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata") - - _strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v))) - - -class ThreadSearchRequest(BaseModel): - """Request body for searching threads.""" - - metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata 
filter (exact match)") - limit: int = Field(default=100, ge=1, le=1000, description="Maximum results") - offset: int = Field(default=0, ge=0, description="Pagination offset") - status: str | None = Field(default=None, description="Filter by thread status") - - -class ThreadStateResponse(BaseModel): - """Response model for thread state.""" - - values: dict[str, Any] = Field(default_factory=dict, description="Current channel values") - next: list[str] = Field(default_factory=list, description="Next tasks to execute") - metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata") - checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info") - checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID") - parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID") - created_at: str | None = Field(default=None, description="Checkpoint timestamp") - tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details") - - -class ThreadPatchRequest(BaseModel): - """Request body for patching thread metadata.""" - - metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge") - - _strip_reserved = field_validator("metadata")(classmethod(lambda cls, v: _strip_reserved_metadata(v))) - - -class ThreadStateUpdateRequest(BaseModel): - """Request body for updating thread state (human-in-the-loop resume).""" - - values: dict[str, Any] | None = Field(default=None, description="Channel values to merge") - checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from") - checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object") - as_node: str | None = Field(default=None, description="Node identity for the update") - - -class HistoryEntry(BaseModel): - """Single checkpoint history entry.""" - - checkpoint_id: str - parent_checkpoint_id: str | None = None - 
metadata: dict[str, Any] = Field(default_factory=dict) - values: dict[str, Any] = Field(default_factory=dict) - created_at: str | None = None - next: list[str] = Field(default_factory=list) - - -class ThreadHistoryRequest(BaseModel): - """Request body for checkpoint history.""" - - limit: int = Field(default=10, ge=1, le=100, description="Maximum entries") - before: str | None = Field(default=None, description="Cursor for pagination") - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _delete_thread_data(thread_id: str, paths: Paths | None = None, *, user_id: str | None = None) -> ThreadDeleteResponse: - """Delete local persisted filesystem data for a thread.""" - path_manager = paths or get_paths() - try: - path_manager.delete_thread_dir(thread_id, user_id=user_id) - except ValueError as exc: - raise HTTPException(status_code=422, detail=str(exc)) from exc - except FileNotFoundError: - # Not critical — thread data may not exist on disk - logger.debug("No local thread data to delete for %s", sanitize_log_param(thread_id)) - return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}") - except Exception as exc: - logger.exception("Failed to delete thread data for %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc - - logger.info("Deleted local thread data for %s", sanitize_log_param(thread_id)) - return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}") - - -def _derive_thread_status(checkpoint_tuple) -> str: - """Derive thread status from checkpoint metadata.""" - if checkpoint_tuple is None: - return "idle" - pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or [] - - # Check for error in pending writes - for pw in pending_writes: - if len(pw) >= 2 and pw[1] == "__error__": - 
return "error" - - # Check for pending next tasks (indicates interrupt) - tasks = getattr(checkpoint_tuple, "tasks", None) - if tasks: - return "interrupted" - - return "idle" - - -# --------------------------------------------------------------------------- -# Endpoints -# --------------------------------------------------------------------------- - - -@router.delete("/{thread_id}", response_model=ThreadDeleteResponse) -@require_permission("threads", "delete", owner_check=True, require_existing=True) -async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse: - """Delete local persisted filesystem data for a thread. - - Cleans DeerFlow-managed thread directories, removes checkpoint data, - and removes the thread_meta row from the configured ThreadMetaStore - (sqlite or memory). - """ - from app.gateway.deps import get_thread_store - - # Clean local filesystem - response = _delete_thread_data(thread_id, user_id=get_effective_user_id()) - - # Remove checkpoints (best-effort) - checkpointer = getattr(request.app.state, "checkpointer", None) - if checkpointer is not None: - try: - if hasattr(checkpointer, "adelete_thread"): - await checkpointer.adelete_thread(thread_id) - except Exception: - logger.debug("Could not delete checkpoints for thread %s (not critical)", sanitize_log_param(thread_id)) - - # Remove thread_meta row (best-effort) — required for sqlite backend - # so the deleted thread no longer appears in /threads/search. - try: - thread_store = get_thread_store(request) - await thread_store.delete(thread_id) - except Exception: - logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id)) - - return response - - -@router.post("", response_model=ThreadResponse) -async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse: - """Create a new thread. 
- - Writes a thread_meta record (so the thread appears in /threads/search) - and an empty checkpoint (so state endpoints work immediately). - Idempotent: returns the existing record when ``thread_id`` already exists. - """ - from app.gateway.deps import get_thread_store - - checkpointer = get_checkpointer(request) - thread_store = get_thread_store(request) - thread_id = body.thread_id or str(uuid.uuid4()) - now = time.time() - # ``body.metadata`` is already stripped of server-reserved keys by - # ``ThreadCreateRequest._strip_reserved`` — see the model definition. - - # Idempotency: return existing record when already present - existing_record = await thread_store.get(thread_id) - if existing_record is not None: - return ThreadResponse( - thread_id=thread_id, - status=existing_record.get("status", "idle"), - created_at=str(existing_record.get("created_at", "")), - updated_at=str(existing_record.get("updated_at", "")), - metadata=existing_record.get("metadata", {}), - ) - - # Write thread_meta so the thread appears in /threads/search immediately - try: - await thread_store.create( - thread_id, - assistant_id=getattr(body, "assistant_id", None), - metadata=body.metadata, - ) - except Exception: - logger.exception("Failed to write thread_meta for %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to create thread") - - # Write an empty checkpoint so state endpoints work immediately - config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} - try: - from langgraph.checkpoint.base import empty_checkpoint - - ckpt_metadata = { - "step": -1, - "source": "input", - "writes": None, - "parents": {}, - **body.metadata, - "created_at": now, - } - await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {}) - except Exception: - logger.exception("Failed to create checkpoint for thread %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to create thread") - - logger.info("Thread 
created: %s", sanitize_log_param(thread_id)) - return ThreadResponse( - thread_id=thread_id, - status="idle", - created_at=str(now), - updated_at=str(now), - metadata=body.metadata, - ) - - -@router.post("/search", response_model=list[ThreadResponse]) -async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]: - """Search and list threads. - - Delegates to the configured ThreadMetaStore implementation - (SQL-backed for sqlite/postgres, Store-backed for memory mode). - """ - from app.gateway.deps import get_thread_store - - repo = get_thread_store(request) - rows = await repo.search( - metadata=body.metadata or None, - status=body.status, - limit=body.limit, - offset=body.offset, - ) - return [ - ThreadResponse( - thread_id=r["thread_id"], - status=r.get("status", "idle"), - created_at=r.get("created_at", ""), - updated_at=r.get("updated_at", ""), - metadata=r.get("metadata", {}), - values={"title": r["display_name"]} if r.get("display_name") else {}, - interrupts={}, - ) - for r in rows - ] - - -@router.patch("/{thread_id}", response_model=ThreadResponse) -@require_permission("threads", "write", owner_check=True, require_existing=True) -async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: - """Merge metadata into a thread record.""" - from app.gateway.deps import get_thread_store - - thread_store = get_thread_store(request) - record = await thread_store.get(thread_id) - if record is None: - raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") - - # ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``. 
- try: - await thread_store.update_metadata(thread_id, body.metadata) - except Exception: - logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to update thread") - - # Re-read to get the merged metadata + refreshed updated_at - record = await thread_store.get(thread_id) or record - return ThreadResponse( - thread_id=thread_id, - status=record.get("status", "idle"), - created_at=str(record.get("created_at", "")), - updated_at=str(record.get("updated_at", "")), - metadata=record.get("metadata", {}), - ) - - -@router.get("/{thread_id}", response_model=ThreadResponse) -@require_permission("threads", "read", owner_check=True) -async def get_thread(thread_id: str, request: Request) -> ThreadResponse: - """Get thread info. - - Reads metadata from the ThreadMetaStore and derives the accurate - execution status from the checkpointer. Falls back to the checkpointer - alone for threads that pre-date ThreadMetaStore adoption (backward compat). - """ - from app.gateway.deps import get_thread_store - - thread_store = get_thread_store(request) - checkpointer = get_checkpointer(request) - - record: dict | None = await thread_store.get(thread_id) - - # Derive accurate status from the checkpointer - config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} - try: - checkpoint_tuple = await checkpointer.aget_tuple(config) - except Exception: - logger.exception("Failed to get checkpoint for thread %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to get thread") - - if record is None and checkpoint_tuple is None: - raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") - - # If the thread exists in the checkpointer but not in thread_meta (e.g. - # legacy data created before thread_meta adoption), synthesize a minimal - # record from the checkpoint metadata. 
- if record is None and checkpoint_tuple is not None: - ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {} - record = { - "thread_id": thread_id, - "status": "idle", - "created_at": ckpt_meta.get("created_at", ""), - "updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")), - "metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}, - } - - if record is None: - raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") - - status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle") - checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {} - channel_values = checkpoint.get("channel_values", {}) - - return ThreadResponse( - thread_id=thread_id, - status=status, - created_at=str(record.get("created_at", "")), - updated_at=str(record.get("updated_at", "")), - metadata=record.get("metadata", {}), - values=serialize_channel_values(channel_values), - ) - - -# --------------------------------------------------------------------------- -@router.get("/{thread_id}/state", response_model=ThreadStateResponse) -@require_permission("threads", "read", owner_check=True) -async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse: - """Get the latest state snapshot for a thread. - - Channel values are serialized to ensure LangChain message objects - are converted to JSON-safe dicts. 
- """ - checkpointer = get_checkpointer(request) - - config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} - try: - checkpoint_tuple = await checkpointer.aget_tuple(config) - except Exception: - logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to get thread state") - - if checkpoint_tuple is None: - raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") - - checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} - metadata = getattr(checkpoint_tuple, "metadata", {}) or {} - checkpoint_id = None - ckpt_config = getattr(checkpoint_tuple, "config", {}) - if ckpt_config: - checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id") - - channel_values = checkpoint.get("channel_values", {}) - - parent_config = getattr(checkpoint_tuple, "parent_config", None) - parent_checkpoint_id = None - if parent_config: - parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") - - tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] - next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")] - tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw] - - values = serialize_channel_values(channel_values) - - return ThreadStateResponse( - values=values, - next=next_tasks, - metadata=metadata, - checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))}, - checkpoint_id=checkpoint_id, - parent_checkpoint_id=parent_checkpoint_id, - created_at=str(metadata.get("created_at", "")), - tasks=tasks, - ) - - -@router.post("/{thread_id}/state", response_model=ThreadStateResponse) -@require_permission("threads", "write", owner_check=True, require_existing=True) -async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse: - """Update thread state (e.g. for human-in-the-loop resume or title rename). 
- - Writes a new checkpoint that merges *body.values* into the latest - channel values, then syncs any updated ``title`` field through the - ThreadMetaStore abstraction so that ``/threads/search`` reflects the - change immediately in both sqlite and memory backends. - """ - from app.gateway.deps import get_thread_store - - checkpointer = get_checkpointer(request) - thread_store = get_thread_store(request) - - # checkpoint_ns must be present in the config for aput — default to "" - # (the root graph namespace). checkpoint_id is optional; omitting it - # fetches the latest checkpoint for the thread. - read_config: dict[str, Any] = { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": "", - } - } - if body.checkpoint_id: - read_config["configurable"]["checkpoint_id"] = body.checkpoint_id - - try: - checkpoint_tuple = await checkpointer.aget_tuple(read_config) - except Exception: - logger.exception("Failed to get state for thread %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to get thread state") - - if checkpoint_tuple is None: - raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") - - # Work on mutable copies so we don't accidentally mutate cached objects. - checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {}) - metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {}) - channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {})) - - if body.values: - channel_values.update(body.values) - - checkpoint["channel_values"] = channel_values - metadata["updated_at"] = time.time() - - if body.as_node: - metadata["source"] = "update" - metadata["step"] = metadata.get("step", 0) + 1 - metadata["writes"] = {body.as_node: body.values} - - # aput requires checkpoint_ns in the config — use the same config used for the - # read (which always includes checkpoint_ns=""). 
Do NOT include checkpoint_id - # so that aput generates a fresh checkpoint ID for the new snapshot. - write_config: dict[str, Any] = { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": "", - } - } - try: - new_config = await checkpointer.aput(write_config, checkpoint, metadata, {}) - except Exception: - logger.exception("Failed to update state for thread %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to update thread state") - - new_checkpoint_id: str | None = None - if isinstance(new_config, dict): - new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id") - - # Sync title changes through the ThreadMetaStore abstraction so /threads/search - # reflects them immediately in both sqlite and memory backends. - if body.values and "title" in body.values: - new_title = body.values["title"] - if new_title: # Skip empty strings and None - try: - await thread_store.update_display_name(thread_id, new_title) - except Exception: - logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) - - return ThreadStateResponse( - values=serialize_channel_values(channel_values), - next=[], - metadata=metadata, - checkpoint_id=new_checkpoint_id, - created_at=str(metadata.get("created_at", "")), - ) - - -@router.post("/{thread_id}/history", response_model=list[HistoryEntry]) -@require_permission("threads", "read", owner_check=True) -async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]: - """Get checkpoint history for a thread. - - Messages are read from the checkpointer's channel values (the - authoritative source) and serialized via - :func:`~deerflow.runtime.serialization.serialize_channel_values`. - Only the latest (first) checkpoint carries the ``messages`` key to - avoid duplicating them across every entry. 
- """ - checkpointer = get_checkpointer(request) - - config: dict[str, Any] = {"configurable": {"thread_id": thread_id}} - if body.before: - config["configurable"]["checkpoint_id"] = body.before - - entries: list[HistoryEntry] = [] - is_latest_checkpoint = True - try: - async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit): - ckpt_config = getattr(checkpoint_tuple, "config", {}) - parent_config = getattr(checkpoint_tuple, "parent_config", None) - metadata = getattr(checkpoint_tuple, "metadata", {}) or {} - checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} - - checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "") - parent_id = None - if parent_config: - parent_id = parent_config.get("configurable", {}).get("checkpoint_id") - - channel_values = checkpoint.get("channel_values", {}) - - # Build values from checkpoint channel_values - values: dict[str, Any] = {} - if title := channel_values.get("title"): - values["title"] = title - if thread_data := channel_values.get("thread_data"): - values["thread_data"] = thread_data - - # Attach messages only to the latest checkpoint entry. 
- if is_latest_checkpoint: - messages = channel_values.get("messages") - if messages: - values["messages"] = serialize_channel_values({"messages": messages}).get("messages", []) - is_latest_checkpoint = False - - # Derive next tasks - tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] - next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")] - - # Strip LangGraph internal keys from metadata - user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")} - # Keep step for ordering context - if "step" in metadata: - user_meta["step"] = metadata["step"] - - entries.append( - HistoryEntry( - checkpoint_id=checkpoint_id, - parent_checkpoint_id=parent_id, - metadata=user_meta, - values=values, - created_at=str(metadata.get("created_at", "")), - next=next_tasks, - ) - ) - except Exception: - logger.exception("Failed to get history for thread %s", sanitize_log_param(thread_id)) - raise HTTPException(status_code=500, detail="Failed to get thread history") - - return entries diff --git a/backend/app/gateway/routers/uploads.py b/backend/app/gateway/routers/uploads.py index e31ff11d2..ef7ba8a1f 100644 --- a/backend/app/gateway/routers/uploads.py +++ b/backend/app/gateway/routers/uploads.py @@ -7,10 +7,10 @@ import stat from fastapi import APIRouter, File, HTTPException, Request, UploadFile from pydantic import BaseModel -from app.gateway.authz import require_permission -from deerflow.config.paths import get_paths -from deerflow.runtime.user_context import get_effective_user_id +from app.plugins.auth.security.actor_context import bind_request_actor_context from deerflow.sandbox.sandbox_provider import get_sandbox_provider +from deerflow.config.paths import get_paths +from deerflow.runtime.actor_context import get_effective_user_id from deerflow.uploads.manager import ( PathTraversalError, delete_file_safe, @@ -56,7 +56,6 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None: 
@router.post("", response_model=UploadResponse) -@require_permission("threads", "write", owner_check=True, require_existing=False) async def upload_files( thread_id: str, request: Request, @@ -66,68 +65,69 @@ async def upload_files( if not files: raise HTTPException(status_code=400, detail="No files provided") - try: - uploads_dir = ensure_uploads_dir(thread_id) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) - uploaded_files = [] - - sandbox_provider = get_sandbox_provider() - sandbox_id = sandbox_provider.acquire(thread_id) - sandbox = sandbox_provider.get(sandbox_id) - - for file in files: - if not file.filename: - continue - + with bind_request_actor_context(request): try: - safe_filename = normalize_filename(file.filename) - except ValueError: - logger.warning(f"Skipping file with unsafe filename: {file.filename!r}") - continue + uploads_dir = ensure_uploads_dir(thread_id) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) + uploaded_files = [] - try: - content = await file.read() - file_path = uploads_dir / safe_filename - file_path.write_bytes(content) + sandbox_provider = get_sandbox_provider() + sandbox_id = sandbox_provider.acquire(thread_id) + sandbox = sandbox_provider.get(sandbox_id) - virtual_path = upload_virtual_path(safe_filename) + for file in files: + if not file.filename: + continue - if sandbox_id != "local": - _make_file_sandbox_writable(file_path) - sandbox.update_file(virtual_path, content) + try: + safe_filename = normalize_filename(file.filename) + except ValueError: + logger.warning(f"Skipping file with unsafe filename: {file.filename!r}") + continue - file_info = { - "filename": safe_filename, - "size": str(len(content)), - "path": str(sandbox_uploads / safe_filename), - "virtual_path": 
virtual_path, - "artifact_url": upload_artifact_url(thread_id, safe_filename), - } + try: + content = await file.read() + file_path = uploads_dir / safe_filename + file_path.write_bytes(content) - logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}") + virtual_path = upload_virtual_path(safe_filename) - file_ext = file_path.suffix.lower() - if file_ext in CONVERTIBLE_EXTENSIONS: - md_path = await convert_file_to_markdown(file_path) - if md_path: - md_virtual_path = upload_virtual_path(md_path.name) + if sandbox_id != "local": + _make_file_sandbox_writable(file_path) + sandbox.update_file(virtual_path, content) - if sandbox_id != "local": - _make_file_sandbox_writable(md_path) - sandbox.update_file(md_virtual_path, md_path.read_bytes()) + file_info = { + "filename": safe_filename, + "size": str(len(content)), + "path": str(sandbox_uploads / safe_filename), + "virtual_path": virtual_path, + "artifact_url": upload_artifact_url(thread_id, safe_filename), + } - file_info["markdown_file"] = md_path.name - file_info["markdown_path"] = str(sandbox_uploads / md_path.name) - file_info["markdown_virtual_path"] = md_virtual_path - file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name) + logger.info(f"Saved file: {safe_filename} ({len(content)} bytes) to {file_info['path']}") - uploaded_files.append(file_info) + file_ext = file_path.suffix.lower() + if file_ext in CONVERTIBLE_EXTENSIONS: + md_path = await convert_file_to_markdown(file_path) + if md_path: + md_virtual_path = upload_virtual_path(md_path.name) - except Exception as e: - logger.error(f"Failed to upload {file.filename}: {e}") - raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}") + if sandbox_id != "local": + _make_file_sandbox_writable(md_path) + sandbox.update_file(md_virtual_path, md_path.read_bytes()) + + file_info["markdown_file"] = md_path.name + file_info["markdown_path"] = str(sandbox_uploads / md_path.name) 
+ file_info["markdown_virtual_path"] = md_virtual_path + file_info["markdown_artifact_url"] = upload_artifact_url(thread_id, md_path.name) + + uploaded_files.append(file_info) + + except Exception as e: + logger.error(f"Failed to upload {file.filename}: {e}") + raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}") return UploadResponse( success=True, @@ -137,26 +137,25 @@ async def upload_files( @router.get("/list", response_model=dict) -@require_permission("threads", "read", owner_check=True) async def list_uploaded_files(thread_id: str, request: Request) -> dict: """List all files in a thread's uploads directory.""" - try: - uploads_dir = get_uploads_dir(thread_id) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - result = list_files_in_dir(uploads_dir) - enrich_file_listing(result, thread_id) + with bind_request_actor_context(request): + try: + uploads_dir = get_uploads_dir(thread_id) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + result = list_files_in_dir(uploads_dir) + enrich_file_listing(result, thread_id) - # Gateway additionally includes the sandbox-relative path. - sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) - for f in result["files"]: - f["path"] = str(sandbox_uploads / f["filename"]) + # Gateway additionally includes the sandbox-relative path. + sandbox_uploads = get_paths().sandbox_uploads_dir(thread_id, user_id=get_effective_user_id()) + for f in result["files"]: + f["path"] = str(sandbox_uploads / f["filename"]) - return result + return result @router.delete("/{filename}") -@require_permission("threads", "delete", owner_check=True, require_existing=True) async def delete_uploaded_file(thread_id: str, filename: str, request: Request) -> dict: """Delete a file from a thread's uploads directory.""" try: