mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-31 04:48:08 +00:00
* fix(gateway): split stream_existing_run into per-method routes for unique OpenAPI operationIds
`@router.api_route("/.../stream", methods=["GET", "POST"])` registers a
single FastAPI route that holds both methods. FastAPI's auto-generated
`operationId` is computed once per route from a single method picked out
of `route.methods`, so when OpenAPI generation iterates over every method
on that route both end up sharing the same `operationId`. That triggers
`UserWarning: Duplicate Operation ID stream_existing_run_..._stream_(get|post) for function stream_existing_run`
during `app.openapi()` and produces an invalid OpenAPI spec for SDK /
codegen consumers.
Register GET and POST as two separate routes on the same handler so each
method gets a distinct auto-generated `operationId` ("..._stream_get" and
"..._stream_post"). Behavior is otherwise unchanged: same handler, same
`require_permission` decoration, same response.
Add `tests/test_openapi_operation_ids.py` to lock in the invariant:
no duplicate-operationId warnings during spec generation, globally unique
operationIds across the spec, and distinct GET / POST operationIds on the
stream endpoint specifically. Reverted the source change locally and
confirmed all three tests fail before the fix.
* test(runtime): widen CancelledError catch in _ScriptedAgent to fix cancel-race flake
`_ScriptedAgent.astream()` previously only caught `asyncio.CancelledError`
inside the inner `if self.block_after_first_chunk:` while-loop. Cancellation
arriving during any earlier `await` in the same body
(`self.model.ainvoke`, `_write_checkpoint`, the `yield`) would propagate
without setting `controller.cancelled`, so callers waiting on
`controller.cancelled.wait(5)` after `POST /cancel` returned 204 could race
and time out.
`test_cancel_interrupt_stops_running_background_run` waits only for the
`started` event (set on the first line of `astream`) before issuing cancel,
so its race window spans all three pre-loop `await`s. On a clean `main`
checkout, stress-running the test 20× reproduces the failure 6/20
(~30%). `test_cancel_rollback_restores_pre_run_checkpoint`, which waits
for the later `checkpoint_written` event, passes 20/20 — confirming the
race lives entirely in the gap between `started.set()` and the
cancellation-aware block.
Widen the try/except to cover the entire `astream` body so any
`CancelledError` sets the controller event; the non-cancel path is
unchanged (no exception means no event set). After this change the
previously flaky test passes 50/50, the rollback test still passes 30/30,
and the full backend suite remains at 3649 passed / 19 skipped.
Test-only change — `backend/tests/test_runtime_lifecycle_e2e.py` is the
only file touched; the production cancel pipeline is unaffected.
439 lines
18 KiB
Python
"""Runs endpoints — create, stream, wait, cancel.
|
|
|
|
Implements the LangGraph Platform runs API on top of
|
|
:class:`deerflow.agents.runs.RunManager` and
|
|
:class:`deerflow.agents.stream_bridge.StreamBridge`.
|
|
|
|
SSE format is aligned with the LangGraph Platform protocol so that
|
|
the ``useStream`` React hook from ``@langchain/langgraph-sdk/react``
|
|
works without modification.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any, Literal
|
|
|
|
from fastapi import APIRouter, HTTPException, Query, Request
|
|
from fastapi.responses import Response, StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
from app.gateway.authz import require_permission
|
|
from app.gateway.deps import get_checkpointer, get_current_user, get_feedback_repo, get_run_event_store, get_run_manager, get_run_store, get_stream_bridge
|
|
from app.gateway.services import sse_consumer, start_run, wait_for_run_completion
|
|
from deerflow.runtime import RunRecord, RunStatus, serialize_channel_values
|
|
|
|
logger = logging.getLogger(name=__name__)
# Every endpoint in this module is mounted under /api/threads and tagged "runs".
router = APIRouter(tags=["runs"], prefix="/api/threads")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Request / response models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class RunCreateRequest(BaseModel):
    """Request body for the run-creation endpoints.

    Accepted by ``POST /{thread_id}/runs``, ``/runs/stream`` and ``/runs/wait``.
    Field names presumably mirror the LangGraph Platform runs API that this
    module implements. Each ``description`` is surfaced in the generated
    OpenAPI schema.
    """

    assistant_id: str | None = Field(default=None, description="Agent / assistant to use")
    input: dict[str, Any] | None = Field(default=None, description="Graph input (e.g. {messages: [...]})")
    command: dict[str, Any] | None = Field(default=None, description="LangGraph Command")
    metadata: dict[str, Any] | None = Field(default=None, description="Run metadata")
    config: dict[str, Any] | None = Field(default=None, description="RunnableConfig overrides")
    context: dict[str, Any] | None = Field(default=None, description="DeerFlow context overrides (model_name, thinking_enabled, etc.)")
    webhook: str | None = Field(default=None, description="Completion callback URL")
    # Checkpoint selection: resume by id, or supply a full checkpoint object.
    checkpoint_id: str | None = Field(default=None, description="Resume from checkpoint")
    checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object")
    # Either an explicit list of node names or the literal "*".
    interrupt_before: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt before")
    interrupt_after: list[str] | Literal["*"] | None = Field(default=None, description="Nodes to interrupt after")
    # A single mode name or a list of modes.
    stream_mode: list[str] | str | None = Field(default=None, description="Stream mode(s)")
    stream_subgraphs: bool = Field(default=False, description="Include subgraph events")
    stream_resumable: bool | None = Field(default=None, description="SSE resumable mode")
    # Default: an SSE client disconnect cancels the run.
    on_disconnect: Literal["cancel", "continue"] = Field(default="cancel", description="Behaviour on SSE disconnect")
    on_completion: Literal["delete", "keep"] = Field(default="keep", description="Delete temp thread on completion")
    # Default: reject a new run while another is in progress on the thread.
    multitask_strategy: Literal["reject", "rollback", "interrupt", "enqueue"] = Field(default="reject", description="Concurrency strategy")
    # NOTE(review): unit is presumably seconds, per the field name — confirm in start_run.
    after_seconds: float | None = Field(default=None, description="Delayed execution")
    if_not_exists: Literal["reject", "create"] = Field(default="create", description="Thread creation policy")
    feedback_keys: list[str] | None = Field(default=None, description="LangSmith feedback keys")
|
|
|
|
|
|
class RunResponse(BaseModel):
    """Public representation of a run, returned by the run endpoints.

    Instances are built from a ``RunRecord`` by ``_record_to_response``.
    """

    run_id: str
    thread_id: str
    assistant_id: str | None = None
    # String value of the record's ``RunStatus`` enum.
    status: str
    metadata: dict[str, Any] = Field(default_factory=dict)
    kwargs: dict[str, Any] = Field(default_factory=dict)
    multitask_strategy: str = "reject"
    # Timestamps are passed through from the record as strings.
    # NOTE(review): the format (presumably ISO 8601) is not visible here — confirm.
    created_at: str = ""
    updated_at: str = ""
    # Token and LLM-call accounting copied from the run record.
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_tokens: int = 0
    llm_call_count: int = 0
    # Token totals split by caller category (lead agent / subagents / middleware).
    lead_agent_tokens: int = 0
    subagent_tokens: int = 0
    middleware_tokens: int = 0
    message_count: int = 0
|
|
|
|
|
|
class ThreadTokenUsageModelBreakdown(BaseModel):
    """Token usage for one entry of ``ThreadTokenUsageResponse.by_model``.

    NOTE(review): the dict key is presumably the model name — confirm in the run store.
    """

    # Tokens attributed to this model across the thread.
    tokens: int = 0
    # Number of runs counted toward this entry.
    runs: int = 0
|
|
|
|
|
|
class ThreadTokenUsageCallerBreakdown(BaseModel):
    """Thread-level token totals split by caller category."""

    # Tokens consumed by the lead agent.
    lead_agent: int = 0
    # Tokens consumed by subagents.
    subagent: int = 0
    # Tokens consumed by middleware.
    middleware: int = 0
|
|
|
|
|
|
class ThreadTokenUsageResponse(BaseModel):
    """Response body of ``GET /{thread_id}/token-usage``.

    Every field except ``thread_id`` is filled from the dict returned by the
    run store's ``aggregate_tokens_by_thread``.
    """

    thread_id: str
    total_tokens: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_runs: int = 0
    # Per-model breakdown.
    by_model: dict[str, ThreadTokenUsageModelBreakdown] = Field(default_factory=dict)
    # Per-caller-category breakdown.
    by_caller: ThreadTokenUsageCallerBreakdown = Field(default_factory=ThreadTokenUsageCallerBreakdown)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _cancel_conflict_detail(run_id: str, record: RunRecord) -> str:
    """Build the 409 ``detail`` text for a cancel request the run manager refused.

    If the record still looks live (pending or running), the refusal is
    reported as the run not being active on this worker. Otherwise the
    record's terminal status is included in the message.
    """
    looks_live = record.status in (RunStatus.pending, RunStatus.running)
    if not looks_live:
        return f"Run {run_id} is not cancellable (status: {record.status.value})"
    return f"Run {run_id} is not active on this worker and cannot be cancelled"
|
|
|
|
|
|
def _record_to_response(record: RunRecord) -> RunResponse:
    """Convert an internal ``RunRecord`` into the public ``RunResponse`` model.

    Every exposed attribute is copied verbatim, except ``status``, which is
    exposed as the enum's ``.value``.
    """
    copied_attrs = (
        "run_id",
        "thread_id",
        "assistant_id",
        "metadata",
        "kwargs",
        "multitask_strategy",
        "created_at",
        "updated_at",
        "total_input_tokens",
        "total_output_tokens",
        "total_tokens",
        "llm_call_count",
        "lead_agent_tokens",
        "subagent_tokens",
        "middleware_tokens",
        "message_count",
    )
    payload = {attr: getattr(record, attr) for attr in copied_attrs}
    return RunResponse(status=record.status.value, **payload)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@router.post("/{thread_id}/runs", response_model=RunResponse)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def create_run(thread_id: str, body: RunCreateRequest, request: Request) -> RunResponse:
    """Start a run in the background and return its record without waiting for it."""
    return _record_to_response(await start_run(body, thread_id, request))
|
|
|
|
|
|
@router.post("/{thread_id}/runs/stream")
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def stream_run(thread_id: str, body: RunCreateRequest, request: Request) -> StreamingResponse:
    """Start a run and stream its events back as Server-Sent Events.

    The response carries a ``Content-Location`` header pointing at the new
    run's resource URL, matching the LangGraph Platform protocol. The
    ``useStream`` React hook reads it to learn the run's metadata.
    """
    stream_bridge = get_stream_bridge(request)
    manager = get_run_manager(request)
    record = await start_run(body, thread_id, request)

    # The SDK pulls the run id out of this path with a greedy regex, so it must
    # be the canonical run resource URL with nothing appended after the id.
    run_location = f"/api/threads/{thread_id}/runs/{record.run_id}"
    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        # Ask reverse proxies (e.g. nginx) not to buffer the event stream.
        "X-Accel-Buffering": "no",
        "Content-Location": run_location,
    }
    return StreamingResponse(
        sse_consumer(stream_bridge, record, request, manager),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
@router.post("/{thread_id}/runs/wait", response_model=dict)
@require_permission("runs", "create", owner_check=True, require_existing=True)
async def wait_run(thread_id: str, body: RunCreateRequest, request: Request) -> dict:
    """Start a run, block until it finishes, and return the thread's final state.

    If the run completed and the thread's latest checkpoint can be read, the
    serialized ``channel_values`` of that checkpoint are returned. In every
    other case (not completed, no checkpoint, or a read failure), a
    ``{"status": ..., "error": ...}`` summary of the run record is returned.
    """
    bridge = get_stream_bridge(request)
    run_mgr = get_run_manager(request)
    record = await start_run(body, thread_id, request)

    # A record without a task has nothing to wait on and counts as completed.
    completed = record.task is None or await wait_for_run_completion(bridge, record, request, run_mgr)

    if completed:
        checkpointer = get_checkpointer(request)
        try:
            latest = await checkpointer.aget_tuple({"configurable": {"thread_id": thread_id}})
            if latest is not None:
                snapshot = getattr(latest, "checkpoint", {}) or {}
                return serialize_channel_values(snapshot.get("channel_values", {}))
        except Exception:
            # Best effort: fall through to the status summary below.
            logger.exception("Failed to fetch final state for run %s", record.run_id)

    return {"status": record.status.value, "error": record.error}
|
|
|
|
|
|
@router.get("/{thread_id}/runs", response_model=list[RunResponse])
@require_permission("runs", "read", owner_check=True)
async def list_runs(thread_id: str, request: Request) -> list[RunResponse]:
    """List the runs of *thread_id*, querying the run manager with the current user's id."""
    manager = get_run_manager(request)
    current_user = await get_current_user(request)
    thread_runs = await manager.list_by_thread(thread_id, user_id=current_user)
    return list(map(_record_to_response, thread_runs))
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}", response_model=RunResponse)
@require_permission("runs", "read", owner_check=True)
async def get_run(thread_id: str, run_id: str, request: Request) -> RunResponse:
    """Return one run of *thread_id*, or 404 if it is unknown or belongs to another thread."""
    manager = get_run_manager(request)
    current_user = await get_current_user(request)
    record = await manager.get(run_id, user_id=current_user)
    # A run that lives on a different thread is reported exactly like a missing one.
    belongs_here = record is not None and record.thread_id == thread_id
    if not belongs_here:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    return _record_to_response(record)
|
|
|
|
|
|
@router.post("/{thread_id}/runs/{run_id}/cancel")
@require_permission("runs", "cancel", owner_check=True, require_existing=True)
async def cancel_run(
    thread_id: str,
    run_id: str,
    request: Request,
    wait: bool = Query(default=False, description="Block until run completes after cancel"),
    action: Literal["interrupt", "rollback"] = Query(default="interrupt", description="Cancel action"),
) -> Response:
    """Cancel a running or pending run.

    - action=interrupt: Stop execution, keep current checkpoint (can be resumed)
    - action=rollback: Stop execution, revert to pre-run checkpoint state
    - wait=true: Block until the run fully stops, return 204
    - wait=false: Return immediately with 202

    Raises:
        HTTPException: 404 if the run is unknown or belongs to another thread;
            409 if the run manager refuses the cancellation.
    """
    run_mgr = get_run_manager(request)
    record = await run_mgr.get(run_id)
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")

    cancelled = await run_mgr.cancel(run_id, action=action)
    if not cancelled:
        raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record))

    if wait and record.task is not None:
        try:
            await record.task
        except asyncio.CancelledError:
            # Presumably the normal outcome for a task that was just cancelled.
            pass
        except Exception:
            # The cancel itself succeeded. An error raised while the task was
            # unwinding must not turn this response into a 500, so log it and
            # still report that the run has stopped. stream_existing_run also
            # treats this case as non-fatal.
            logger.exception("Run %s raised an error while stopping after cancel", run_id)
        return Response(status_code=204)

    return Response(status_code=202)
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}/join")
@require_permission("runs", "read", owner_check=True)
async def join_run(thread_id: str, run_id: str, request: Request) -> StreamingResponse:
    """Attach to the SSE stream of an existing run.

    Raises:
        HTTPException: 404 if the run is unknown or belongs to another thread;
            409 if the record is ``store_only``.
    """
    manager = get_run_manager(request)
    record = await manager.get(run_id)
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    if record.store_only:
        # Per the error text, a store-only run is not active on this worker,
        # so there is no live stream to join here.
        raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")

    sse_headers = {"Cache-Control": "no-cache", "Connection": "keep-alive", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        sse_consumer(get_stream_bridge(request), record, request, manager),
        media_type="text/event-stream",
        headers=sse_headers,
    )
|
|
|
|
|
|
# Register GET and POST as separate routes so each method gets a unique OpenAPI
# operationId. ``api_route(methods=["GET", "POST"])`` shares one route registration
# across both methods, which makes FastAPI emit the same ``operationId`` twice and
# warn about a duplicate operation id during OpenAPI generation.
@router.get("/{thread_id}/runs/{run_id}/stream", response_model=None)
@router.post("/{thread_id}/runs/{run_id}/stream", response_model=None)
@require_permission("runs", "read", owner_check=True)
async def stream_existing_run(
    thread_id: str,
    run_id: str,
    request: Request,
    action: Literal["interrupt", "rollback"] | None = Query(default=None, description="Cancel action"),
    wait: int = Query(default=0, description="Block until cancelled (1) or return immediately (0)"),
):
    """Join an existing run's SSE stream (GET), or cancel-then-stream (POST).

    The LangGraph SDK's ``joinStream`` and ``useStream`` stop button both use
    ``POST`` to this endpoint. When ``action=interrupt`` or ``action=rollback``
    is present the run is cancelled first; the response then streams any
    remaining buffered events so the client observes a clean shutdown.

    The handler does not distinguish GET from POST: ``action`` is honoured on
    either method. When ``action`` is given and ``wait`` is non-zero, the
    handler blocks until the run's task stops and returns an empty 204 instead
    of streaming.

    Raises:
        HTTPException: 404 if the run is unknown or belongs to another thread;
            409 if a store-only run is joined without an ``action``, or if the
            run manager refuses the cancellation.
    """
    run_mgr = get_run_manager(request)
    record = await run_mgr.get(run_id)
    if record is None or record.thread_id != thread_id:
        raise HTTPException(status_code=404, detail=f"Run {run_id} not found")
    if record.store_only and action is None:
        raise HTTPException(status_code=409, detail=f"Run {run_id} is not active on this worker and cannot be streamed")

    # Cancel if an action was requested (stop-button / interrupt flow)
    if action is not None:
        cancelled = await run_mgr.cancel(run_id, action=action)
        if not cancelled:
            raise HTTPException(status_code=409, detail=_cancel_conflict_detail(run_id, record))
        if wait and record.task is not None:
            try:
                await record.task
            except asyncio.CancelledError:
                # Presumably the normal outcome for a task that was just cancelled.
                pass
            except Exception:
                # Best effort: the cancel succeeded, so still answer 204, but
                # keep the failure visible instead of silently discarding it.
                logger.exception("Run %s raised an error while stopping after cancel", run_id)
            return Response(status_code=204)

    bridge = get_stream_bridge(request)
    return StreamingResponse(
        sse_consumer(bridge, record, request, run_mgr),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Messages / Events / Token usage endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@router.get("/{thread_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_thread_messages(
    thread_id: str,
    request: Request,
    # ge=1 matches list_run_messages. It stops zero or negative limits from
    # slipping past the le=200 cap; SQLite, for one, treats a negative LIMIT
    # as "no limit".
    limit: int = Query(default=50, ge=1, le=200),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> list[dict]:
    """Return displayable messages for a thread (across all runs), with feedback attached.

    Every returned message gets a ``feedback`` key. For the last
    ``ai_message`` of each run *within the returned page*, it holds that
    run's feedback entry (``feedback_id``, ``rating``, ``comment``), looked up
    with the current user's id, or ``None`` when there is no entry. Every
    other message gets ``None``.
    """
    event_store = get_run_event_store(request)
    messages = await event_store.list_messages(thread_id, limit=limit, before_seq=before_seq, after_seq=after_seq)

    feedback_repo = get_feedback_repo(request)
    user_id = await get_current_user(request)
    feedback_map = await feedback_repo.list_by_thread_grouped(thread_id, user_id=user_id)

    # Index of the last ai_message per run in this page. Later entries
    # overwrite earlier ones, so the stored value is the last occurrence.
    last_ai_per_run: dict[str, int] = {}
    for i, msg in enumerate(messages):
        if msg.get("event_type") == "ai_message":
            last_ai_per_run[msg["run_id"]] = i
    last_ai_indices = set(last_ai_per_run.values())

    for i, msg in enumerate(messages):
        fb = feedback_map.get(msg["run_id"]) if i in last_ai_indices else None
        msg["feedback"] = (
            {
                "feedback_id": fb["feedback_id"],
                "rating": fb["rating"],
                "comment": fb.get("comment"),
            }
            if fb
            else None
        )

    return messages
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}/messages")
@require_permission("runs", "read", owner_check=True)
async def list_run_messages(
    thread_id: str,
    run_id: str,
    request: Request,
    limit: int = Query(default=50, le=200, ge=1),
    before_seq: int | None = Query(default=None),
    after_seq: int | None = Query(default=None),
) -> dict:
    """Return one page of messages for a specific run.

    Response shape: ``{"data": [...], "has_more": bool}``.
    """
    event_store = get_run_event_store(request)
    # Over-fetch by one row: if it comes back, another page exists.
    page = await event_store.list_messages_by_run(
        thread_id,
        run_id,
        limit=limit + 1,
        before_seq=before_seq,
        after_seq=after_seq,
    )
    has_more = len(page) > limit
    if has_more:
        page = page[:limit]
    return {"data": page, "has_more": has_more}
|
|
|
|
|
|
@router.get("/{thread_id}/runs/{run_id}/events")
@require_permission("runs", "read", owner_check=True)
async def list_run_events(
    thread_id: str,
    run_id: str,
    request: Request,
    event_types: str | None = Query(default=None),
    # ge=1 stops zero or negative limits from slipping past the le=2000 cap.
    limit: int = Query(default=500, ge=1, le=2000),
) -> list[dict]:
    """Return the full event stream for a run (debug/audit).

    Args:
        event_types: Optional comma-separated list of event types to keep
            (e.g. ``ai_message``). Whitespace around entries and empty
            entries are ignored. A filter with no usable entries means
            "all event types".
        limit: Maximum number of events to return (1..2000).
    """
    event_store = get_run_event_store(request)
    types: list[str] | None = None
    if event_types:
        # Tolerate "a, b" and stray commas instead of matching " b" or "".
        types = [name.strip() for name in event_types.split(",") if name.strip()] or None
    return await event_store.list_events(thread_id, run_id, event_types=types, limit=limit)
|
|
|
|
|
|
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True)
async def thread_token_usage(
    thread_id: str,
    request: Request,
    include_active: bool = Query(default=False, description="Include running run progress snapshots"),
) -> ThreadTokenUsageResponse:
    """Aggregate token usage across the runs of a thread.

    ``include_active=True`` is forwarded to the run store only when requested.
    The default request calls ``aggregate_tokens_by_thread(thread_id)`` with no
    extra keyword. NOTE(review): presumably this keeps store implementations
    without that keyword working — confirm before collapsing the call.
    """
    run_store = get_run_store(request)
    store_kwargs = {"include_active": True} if include_active else {}
    totals = await run_store.aggregate_tokens_by_thread(thread_id, **store_kwargs)
    return ThreadTokenUsageResponse(thread_id=thread_id, **totals)
|