refactor(runtime): introduce RunContext to reduce run_agent parameter bloat

Extract checkpointer, store, event_store, run_events_config, thread_meta_repo,
and follow_up_to_run_id into a frozen RunContext dataclass. Add get_run_context()
in deps.py to build the base context from app.state singletons. start_run() uses
dataclasses.replace() to enrich per-run fields before passing ctx to run_agent.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng 2026-04-06 10:59:47 +08:00
parent 8746a2bcd9
commit eba6810a44
5 changed files with 67 additions and 34 deletions

View File

@ -14,7 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager
from fastapi import FastAPI, HTTPException, Request
from deerflow.runtime import RunManager
from deerflow.runtime import RunContext, RunManager
@asynccontextmanager
@ -109,6 +109,25 @@ def get_thread_meta_repo(request: Request):
return getattr(request.app.state, "thread_meta_repo", None)
def get_run_context(request: Request) -> RunContext:
"""Build a :class:`RunContext` from ``app.state`` singletons.
Returns a *base* context with infrastructure dependencies. Callers that
need per-run fields (e.g. ``follow_up_to_run_id``) should use
``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it
to :func:`run_agent`.
"""
from deerflow.config import get_app_config
return RunContext(
checkpointer=get_checkpointer(request),
store=get_store(request),
event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None),
thread_meta_repo=get_thread_meta_repo(request),
)
async def get_current_user(request: Request) -> str | None:
"""Extract user identity from request.

View File

@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules
from __future__ import annotations
import asyncio
import dataclasses
import json
import logging
import re
@ -17,7 +18,7 @@ from typing import Any
from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo
from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge
from app.gateway.routers.threads import _sanitize_log_param
from deerflow.runtime import (
END_SENTINEL,
@ -256,14 +257,7 @@ async def start_run(
"""
bridge = get_stream_bridge(request)
run_mgr = get_run_manager(request)
checkpointer = get_checkpointer(request)
store = get_store(request)
event_store = get_run_event_store(request)
# Get run_events config for journal
from deerflow.config import get_app_config
run_events_config = getattr(get_app_config(), "run_events", None)
run_ctx = get_run_context(request)
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
@ -278,6 +272,10 @@ async def start_run(
except Exception:
pass # Don't block run creation
# Enrich base context with per-run field
if follow_up_to_run_id:
run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id)
try:
record = await run_mgr.create_or_reject(
thread_id,
@ -295,23 +293,21 @@ async def start_run(
# Ensure the thread is visible in /threads/search, even for threads that
# were never explicitly created via POST /threads (e.g. stateless runs).
store = get_store(request)
if store is not None:
await _upsert_thread_in_store(store, thread_id, body.metadata)
if run_ctx.store is not None:
await _upsert_thread_in_store(run_ctx.store, thread_id, body.metadata)
# Upsert thread metadata in the SQL-backed threads_meta table
thread_meta_repo = get_thread_meta_repo(request)
if thread_meta_repo is not None:
if run_ctx.thread_meta_repo is not None:
try:
existing = await thread_meta_repo.get(thread_id)
existing = await run_ctx.thread_meta_repo.get(thread_id)
if existing is None:
await thread_meta_repo.create(
await run_ctx.thread_meta_repo.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await thread_meta_repo.update_status(thread_id, "running")
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
@ -346,8 +342,7 @@ async def start_run(
bridge,
run_mgr,
record,
checkpointer=checkpointer,
store=store,
ctx=run_ctx,
agent_factory=agent_factory,
graph_input=graph_input,
config=config,
@ -355,10 +350,6 @@ async def start_run(
stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
event_store=event_store,
run_events_config=run_events_config,
follow_up_to_run_id=follow_up_to_run_id,
thread_meta_repo=thread_meta_repo,
)
)
record.task = task
@ -366,8 +357,8 @@ async def start_run(
# After the run completes, sync the title generated by TitleMiddleware from
# the checkpointer into the Store record so that /threads/search returns the
# correct title instead of an empty values dict.
if store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store))
if run_ctx.store is not None:
asyncio.create_task(_sync_thread_title_after_run(task, thread_id, run_ctx.checkpointer, run_ctx.store))
return record

View File

@ -5,7 +5,7 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and
directly from ``deerflow.runtime``.
"""
from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent
from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple
from .store import get_store, make_store, reset_store, store_context
from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge
@ -14,6 +14,7 @@ __all__ = [
# runs
"ConflictError",
"DisconnectMode",
"RunContext",
"RunManager",
"RunRecord",
"RunStatus",

View File

@ -2,11 +2,12 @@
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import run_agent
from .worker import RunContext, run_agent
__all__ = [
"ConflictError",
"DisconnectMode",
"RunContext",
"RunManager",
"RunRecord",
"RunStatus",

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
@ -34,13 +35,29 @@ logger = logging.getLogger(__name__)
_VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"}
@dataclass(frozen=True)
class RunContext:
"""Infrastructure dependencies for a single agent run.
Groups checkpointer, store, and persistence-related singletons so that
``run_agent`` (and any future callers) receive one object instead of a
growing list of keyword arguments.
"""
checkpointer: Any
store: Any | None = field(default=None)
event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None)
thread_meta_repo: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None)
async def run_agent(
bridge: StreamBridge,
run_manager: RunManager,
record: RunRecord,
*,
checkpointer: Any,
store: Any | None = None,
ctx: RunContext,
agent_factory: Any,
graph_input: dict,
config: dict,
@ -48,13 +65,17 @@ async def run_agent(
stream_subgraphs: bool = False,
interrupt_before: list[str] | Literal["*"] | None = None,
interrupt_after: list[str] | Literal["*"] | None = None,
event_store: Any | None = None,
run_events_config: Any | None = None,
follow_up_to_run_id: str | None = None,
thread_meta_repo: Any | None = None,
) -> None:
"""Execute an agent in the background, publishing events to *bridge*."""
# Unpack infrastructure dependencies from RunContext.
checkpointer = ctx.checkpointer
store = ctx.store
event_store = ctx.event_store
run_events_config = ctx.run_events_config
thread_meta_repo = ctx.thread_meta_repo
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id
thread_id = record.thread_id
requested_modes: set[str] = set(stream_modes or ["values"])