diff --git a/backend/app/gateway/deps.py b/backend/app/gateway/deps.py index 7e418c5ce..d8af19d4e 100644 --- a/backend/app/gateway/deps.py +++ b/backend/app/gateway/deps.py @@ -14,7 +14,7 @@ from contextlib import AsyncExitStack, asynccontextmanager from fastapi import FastAPI, HTTPException, Request -from deerflow.runtime import RunManager +from deerflow.runtime import RunContext, RunManager @asynccontextmanager @@ -109,6 +109,25 @@ def get_thread_meta_repo(request: Request): return getattr(request.app.state, "thread_meta_repo", None) +def get_run_context(request: Request) -> RunContext: + """Build a :class:`RunContext` from ``app.state`` singletons. + + Returns a *base* context with infrastructure dependencies. Callers that + need per-run fields (e.g. ``follow_up_to_run_id``) should use + ``dataclasses.replace(ctx, follow_up_to_run_id=...)`` before passing it + to :func:`run_agent`. + """ + from deerflow.config import get_app_config + + return RunContext( + checkpointer=get_checkpointer(request), + store=get_store(request), + event_store=get_run_event_store(request), + run_events_config=getattr(get_app_config(), "run_events", None), + thread_meta_repo=get_thread_meta_repo(request), + ) + + async def get_current_user(request: Request) -> str | None: """Extract user identity from request. diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 36c47c4f9..31833b822 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -8,6 +8,7 @@ frames, and consuming stream bridge events. Router modules from __future__ import annotations import asyncio +import dataclasses import json import logging import re @@ -17,7 +18,7 @@ from typing import Any from fastapi import HTTPException, Request from langchain_core.messages import HumanMessage -from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo +from app.gateway.deps import get_run_context, get_run_manager, get_run_store, get_stream_bridge from app.gateway.routers.threads import _sanitize_log_param from deerflow.runtime import ( END_SENTINEL, @@ -256,14 +257,7 @@ async def start_run( """ bridge = get_stream_bridge(request) run_mgr = get_run_manager(request) - checkpointer = get_checkpointer(request) - store = get_store(request) - event_store = get_run_event_store(request) - - # Get run_events config for journal - from deerflow.config import get_app_config - - run_events_config = getattr(get_app_config(), "run_events", None) + run_ctx = get_run_context(request) disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_ @@ -278,6 +272,10 @@ async def start_run( except Exception: pass # Don't block run creation + # Enrich base context with per-run field + if follow_up_to_run_id: + run_ctx = dataclasses.replace(run_ctx, follow_up_to_run_id=follow_up_to_run_id) + try: record = await run_mgr.create_or_reject( thread_id, @@ -295,23 +293,21 @@ async def start_run( # Ensure the thread is visible in /threads/search, even for threads that # were never explicitly created via POST /threads (e.g. stateless runs). - store = get_store(request) - if store is not None: - await _upsert_thread_in_store(store, thread_id, body.metadata) + if run_ctx.store is not None: + await _upsert_thread_in_store(run_ctx.store, thread_id, body.metadata) # Upsert thread metadata in the SQL-backed threads_meta table - thread_meta_repo = get_thread_meta_repo(request) - if thread_meta_repo is not None: + if run_ctx.thread_meta_repo is not None: try: - existing = await thread_meta_repo.get(thread_id) + existing = await run_ctx.thread_meta_repo.get(thread_id) if existing is None: - await thread_meta_repo.create( + await run_ctx.thread_meta_repo.create( thread_id, assistant_id=body.assistant_id, metadata=body.metadata, ) else: - await thread_meta_repo.update_status(thread_id, "running") + await run_ctx.thread_meta_repo.update_status(thread_id, "running") except Exception: logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id)) @@ -346,8 +342,7 @@ async def start_run( bridge, run_mgr, record, - checkpointer=checkpointer, - store=store, + ctx=run_ctx, agent_factory=agent_factory, graph_input=graph_input, config=config, @@ -355,10 +350,6 @@ async def start_run( stream_subgraphs=body.stream_subgraphs, interrupt_before=body.interrupt_before, interrupt_after=body.interrupt_after, - event_store=event_store, - run_events_config=run_events_config, - follow_up_to_run_id=follow_up_to_run_id, - thread_meta_repo=thread_meta_repo, ) ) record.task = task @@ -366,8 +357,8 @@ async def start_run( # After the run completes, sync the title generated by TitleMiddleware from # the checkpointer into the Store record so that /threads/search returns the # correct title instead of an empty values dict. - if store is not None: - asyncio.create_task(_sync_thread_title_after_run(task, thread_id, checkpointer, store)) + if run_ctx.store is not None: + asyncio.create_task(_sync_thread_title_after_run(task, thread_id, run_ctx.checkpointer, run_ctx.store)) return record diff --git a/backend/packages/harness/deerflow/runtime/__init__.py b/backend/packages/harness/deerflow/runtime/__init__.py index d7eccf101..d5faa9018 100644 --- a/backend/packages/harness/deerflow/runtime/__init__.py +++ b/backend/packages/harness/deerflow/runtime/__init__.py @@ -5,7 +5,7 @@ Re-exports the public API of :mod:`~deerflow.runtime.runs` and directly from ``deerflow.runtime``. """ -from .runs import ConflictError, DisconnectMode, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent +from .runs import ConflictError, DisconnectMode, RunContext, RunManager, RunRecord, RunStatus, UnsupportedStrategyError, run_agent from .serialization import serialize, serialize_channel_values, serialize_lc_object, serialize_messages_tuple from .store import get_store, make_store, reset_store, store_context from .stream_bridge import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, StreamBridge, StreamEvent, make_stream_bridge @@ -14,6 +14,7 @@ __all__ = [ # runs "ConflictError", "DisconnectMode", + "RunContext", "RunManager", "RunRecord", "RunStatus", diff --git a/backend/packages/harness/deerflow/runtime/runs/__init__.py b/backend/packages/harness/deerflow/runtime/runs/__init__.py index afed90f48..9faa30c17 100644 --- a/backend/packages/harness/deerflow/runtime/runs/__init__.py +++ b/backend/packages/harness/deerflow/runtime/runs/__init__.py @@ -2,11 +2,12 @@ from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError from .schemas import DisconnectMode, RunStatus -from .worker import run_agent +from .worker import RunContext, run_agent __all__ = [ "ConflictError", "DisconnectMode", + "RunContext", "RunManager", "RunRecord", "RunStatus", diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 535aacc0f..00de0a2d1 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import logging +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: @@ -34,13 +35,29 @@ logger = logging.getLogger(__name__) _VALID_LG_MODES = {"values", "updates", "checkpoints", "tasks", "debug", "messages", "custom"} +@dataclass(frozen=True) +class RunContext: + """Infrastructure dependencies for a single agent run. + + Groups checkpointer, store, and persistence-related singletons so that + ``run_agent`` (and any future callers) receive one object instead of a + growing list of keyword arguments. + """ + + checkpointer: Any + store: Any | None = field(default=None) + event_store: Any | None = field(default=None) + run_events_config: Any | None = field(default=None) + thread_meta_repo: Any | None = field(default=None) + follow_up_to_run_id: str | None = field(default=None) + + async def run_agent( bridge: StreamBridge, run_manager: RunManager, record: RunRecord, *, - checkpointer: Any, - store: Any | None = None, + ctx: RunContext, agent_factory: Any, graph_input: dict, config: dict, @@ -48,13 +65,17 @@ async def run_agent( stream_subgraphs: bool = False, interrupt_before: list[str] | Literal["*"] | None = None, interrupt_after: list[str] | Literal["*"] | None = None, - event_store: Any | None = None, - run_events_config: Any | None = None, - follow_up_to_run_id: str | None = None, - thread_meta_repo: Any | None = None, ) -> None: """Execute an agent in the background, publishing events to *bridge*.""" + # Unpack infrastructure dependencies from RunContext. + checkpointer = ctx.checkpointer + store = ctx.store + event_store = ctx.event_store + run_events_config = ctx.run_events_config + thread_meta_repo = ctx.thread_meta_repo + follow_up_to_run_id = ctx.follow_up_to_run_id + run_id = record.run_id thread_id = record.thread_id requested_modes: set[str] = set(stream_modes or ["values"])