mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
- Replaced `fetchWithAuth` with a generic `fetch` function across various API modules for consistency. - Updated `useThreadStream` and `useThreadHistory` hooks to manage chat history loading, including loading states and pagination. - Introduced `LoadMoreHistoryIndicator` component for better user experience when loading more chat history. - Enhanced message handling in `MessageList` to accommodate new loading states and history management. - Added support for run messages in the thread context, improving the overall message handling logic. - Updated translations for loading indicators in English and Chinese.
378 lines
15 KiB
Python
"""Runs endpoints — create, stream, wait, cancel.
|
|
|
|
Implements the LangGraph Platform runs API on top of
|
|
:class:`deerflow.agents.runs.RunManager` and
|
|
:class:`deerflow.agents.stream_bridge.StreamBridge`.
|
|
|
|
SSE format is aligned with the LangGraph Platform protocol so that
|
|
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
|
|
works without modification.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Literal
|
|
|
|
from fastapi import APIRouter, HTTPException, Query, Request
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.gateway.authz import require_permission
|
|
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
|
from app.gateway.services import sse_consumer, start_run
|
|
from deerflow.runtime import RunRecord, serialize_channel_values
|
|
|
|
# Module-level logger; endpoints use it to report run lifecycle failures.
logger = logging.getLogger(__name__)

# All run endpoints live under a thread resource, matching the LangGraph
# Platform URL layout (/api/threads/{thread_id}/runs/...).
router = APIRouter(prefix="/api/threads", tags=["runs"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request / response models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class RunCreateRequest(BaseModel):
    """Payload for creating a run (background, streaming, or blocking).

    Mirrors the LangGraph Platform run-creation schema so that SDK
    clients (e.g. ``@langchain/langgraph-sdk``) can call these endpoints
    unmodified.  Every field is an optional override; the defaults follow
    the platform's documented behaviour.
    """

    assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
    input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
    command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
    metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
    config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
    context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
    webhook: str | None = Field(default=None, description="Completion callback URL")
    checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
    checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
    interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
    interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
    stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
    stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
    stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
    on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
    on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
    multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
    after_seconds: float | None = Field(default=None, description="Delayed execution")
    if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
    feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
|
|
|
|
|
class RunResponse(BaseModel):
    """Public representation of a run record returned by the API.

    Field values are copied from the internal ``RunRecord`` by
    ``_record_to_response``; timestamps are pre-serialized strings.
    """

    run_id: str
    thread_id: str
    assistant_id: str | None = None
    # Serialized status enum value (e.g. "pending", "running").
    status: str
    metadata: dict[str, Any] = Field(default_factory=dict)
    kwargs: dict[str, Any] = Field(default_factory=dict)
    multitask_strategy: str = "reject"
    created_at: str = ""
    updated_at: str = ""
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _record_to_response(record: RunRecord) -> RunResponse:
    """Map an internal :class:`RunRecord` onto the public API schema."""
    payload: dict[str, Any] = {
        "run_id": record.run_id,
        "thread_id": record.thread_id,
        "assistant_id": record.assistant_id,
        # The status enum is serialized as its string value.
        "status": record.status.value,
        "metadata": record.metadata,
        "kwargs": record.kwargs,
        "multitask_strategy": record.multitask_strategy,
        "created_at": record.created_at,
        "updated_at": record.updated_at,
    }
    return RunResponse(**payload)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
    """Create a background run and return immediately (no streaming)."""
    return _record_to_response(await start_run(body, thread_id, request))
|
|
|
|
|
|
@router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
    """Create a run and stream its events over SSE.

    A ``Content-Location`` header carries the run's resource URL, as in
    the LangGraph Platform protocol; the ``useStream`` React hook parses
    run metadata out of that header.
    """
    stream_bridge = get_stream_bridge(request)
    manager = get_run_manager(request)
    record = await start_run(body, thread_id, request)

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        # LangGraph Platform includes run metadata in this header.
        # The SDK uses a greedy regex to extract the run id from this path,
        # so it must point at the canonical run resource without extra suffixes.
        "Content-Location": f"/api/threads/{thread_id}/runs/{record.run_id}",
    }
    return StreamingResponse(
        sse_consumer(stream_bridge, record, request, manager),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
    """Create a run, block until it finishes, and return the final state."""
    record = await start_run(body, thread_id, request)

    task = record.task
    if task is not None:
        try:
            await task
        except asyncio.CancelledError:
            # A cancelled run still has a reportable final status below.
            pass

    # Surface the thread's latest checkpoint channel values as the
    # run's "final state".
    checkpointer = get_checkpointer(request)
    config = {"configurable": {"thread_id": thread_id}}
    try:
        ckpt_tuple = await checkpointer.aget_tuple(config)
        if ckpt_tuple is not None:
            checkpoint = getattr(ckpt_tuple, "checkpoint", {}) or {}
            return serialize_channel_values(checkpoint.get("channel_values", {}))
    except Exception:
        logger.exception("Failed to fetch final state for run %s", record.run_id)

    # Fallback when no checkpoint is available (or the lookup failed).
    return {"status": record.status.value, "error": record.error}
|
|
|
|
|
|
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
    """Return every run recorded for the given thread."""
    manager = get_run_manager(request)
    thread_records = await manager.list_by_thread(thread_id)
    return [_record_to_response(rec) for rec in thread_records]
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
    """Fetch one run; 404 when it is absent or belongs to another thread."""
    record = get_run_manager(request).get(run_id)
    # Treat a run that exists under a different thread as not-found to
    # avoid leaking cross-thread run ids.
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    return _record_to_response(record)
|
|
|
|
|
|
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
    thread_id: str,
    run_id: str,
    request: Request,
    wait: bool = Query(default=False, description="Block until run completes after cancel"),
    action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
    """Cancel a running or pending run.

    - action=interrupt: Stop execution, keep current checkpoint (can be resumed)
    - action=rollback: Stop execution, revert to pre-run checkpoint state
    - wait=true: Block until the run fully stops, return 204
    - wait=false: Return immediately with 202

    Raises:
        HTTPException: 404 if the run does not exist under this thread;
            409 if the run is not in a cancellable state.
    """
    run_mgr = get_run_manager(request)
    record = run_mgr.get(run_id)
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    cancelled = await run_mgr.cancel(run_id, action=action)
    if not cancelled:
        raise HTTPException(
            status_code=409,
            detail=f"Run {run_id} is not cancellable (status: {record.status.value})",
        )

    if wait:
        # Await the task when one exists; either way the documented
        # contract for wait=true is a 204 once the run has stopped.
        # (Previously wait=true with no task fell through to 202.)
        if record.task is not None:
            try:
                await record.task
            except asyncio.CancelledError:
                pass
        return Response(status_code=204)

    return Response(status_code=202)
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
    """Attach to an existing run's SSE stream without creating a new run."""
    stream_bridge = get_stream_bridge(request)
    manager = get_run_manager(request)

    record = manager.get(run_id)
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        sse_consumer(stream_bridge, record, request, manager),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
@router.api_route("/{thread_id}/runs/{run_id}/stream", methods=["GET", "POST"], response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
    thread_id: str,
    run_id: str,
    request: Request,
    action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
    wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
    """Join an existing run's SSE stream (GET), or cancel-then-stream (POST).

    The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
    ``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
    is present the run is cancelled first; the response then streams any
    remaining buffered events so the client observes a clean shutdown.

    Raises:
        HTTPException: 404 if the run does not exist under this thread.
    """
    run_mgr = get_run_manager(request)
    record = run_mgr.get(run_id)
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    # Cancel if an action was requested (stop-button / interrupt flow)
    if action is not None:
        cancelled = await run_mgr.cancel(run_id, action=action)
        # With wait=1 the client only wants confirmation that the run has
        # stopped, not the remaining stream.
        # NOTE(review): CancelledError is a BaseException on 3.8+, so the
        # tuple is needed to swallow both it and ordinary task failures —
        # confirm silently ignoring task exceptions here is intentional.
        if cancelled and wait and record.task is not None:
            try:
                await record.task
            except (asyncio.CancelledError, Exception):
                pass
            return Response(status_code=204)

    # Default path: attach to the run's SSE stream (any buffered events
    # are replayed by the bridge consumer).
    bridge = get_stream_bridge(request)
    return StreamingResponse(
        sse_consumer(bridge, record, request, run_mgr),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Messages / Events / Token usage endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
    thread_id: str,
    request: Request,
    limit: int = Query(default=50, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """Return displayable messages for a thread (across all runs), with feedback attached."""
    event_store = get_run_event_store(request)
    messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)

    # Feedback is only surfaced on the final AI message of each run.
    feedback_repo = get_feedback_repo(request)
    user_id = await get_current_user(request)
    feedback_by_run = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)

    # Default every message to no feedback while locating the index of the
    # last ai_message per run.
    last_ai_index: dict[str, int] = {}
    for idx, message in enumerate(messages):
        message["feedback"] = None
        if message.get("event_type") == "ai_message":
            last_ai_index[message["run_id"]] = idx

    # Overwrite the final AI message of each run with its stored feedback.
    for run_id, idx in last_ai_index.items():
        fb = feedback_by_run.get(run_id)
        if fb:
            messages[idx]["feedback"] = {
                "feedback_id": fb["feedback_id"],
                "rating": fb["rating"],
                "comment": fb.get("comment"),
            }

    return messages
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(
    thread_id: str,
    run_id: str,
    request: Request,
    limit: int = Query(default=50, le=200, ge=1),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> dict:
    """Return paginated messages for a specific run.

    Response: { data: [...], has_more: bool }
    """
    event_store = get_run_event_store(request)
    # Over-fetch by one row so we can tell whether another page exists
    # without a separate count query.
    rows = await event_store.list_messages_by_run(
        thread_id,
        run_id,
        limit=limit + 1,
        before_seq=before_seq,
        after_seq=after_seq,
    )
    has_more = len(rows) > limit
    return {"data": rows[:limit], "has_more": has_more}
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    limit: int = Query(default=500, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit)."""
    store = get_run_event_store(request)
    # A comma-separated filter; None (or empty string) means all types.
    selected = event_types.split(",") if event_types else None
    return await store.list_events(thread_id, run_id, event_types=selected, limit=limit)
|
|
|
|
|
|
@router.get("/{thread_id}/token-usage")
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Thread-level token usage aggregation."""
    store = get_run_store(request)
    usage = await store.aggregate_tokens_by_thread(thread_id)
    return {"thread_id": thread_id, **usage}
|