"""Factory for app-owned run event store backends."""

from __future__ import annotations

from pathlib import Path

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.infra.storage import AppRunEventStore
from deerflow.config import get_app_config

from .jsonl_store import JsonlRunEventStore


def build_run_event_store(session_factory: async_sessionmaker[AsyncSession]) -> AppRunEventStore | JsonlRunEventStore:
    """Build the run event store selected by app configuration.

    The backend is read from ``get_app_config().run_events.backend``:

    * ``"db"``    -> :class:`AppRunEventStore` backed by the given session factory.
    * ``"jsonl"`` -> :class:`JsonlRunEventStore` writing under ``config.jsonl_base_dir``
      (the session factory is intentionally unused for this backend).

    Raises:
        ValueError: if the configured backend name is not recognized.
    """

    config = get_app_config().run_events
    if config.backend == "db":
        return AppRunEventStore(session_factory)
    if config.backend == "jsonl":
        # File-based backend: only the base directory matters here.
        return JsonlRunEventStore(
            base_dir=Path(config.jsonl_base_dir),
        )
    raise ValueError(f"Unsupported run event backend: {config.backend}")
infrastructure.""" + +from __future__ import annotations + +import asyncio +import json +import shutil +from collections.abc import Iterable +from datetime import UTC, datetime +from pathlib import Path +from typing import Any + + +class JsonlRunEventStore: + """Append-only JSONL implementation of the runs RunEventStore protocol.""" + + def __init__( + self, + base_dir: Path | str = ".deer-flow/run-events", + ) -> None: + self._base_dir = Path(base_dir) + self._locks: dict[str, asyncio.Lock] = {} + self._locks_guard = asyncio.Lock() + + async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not events: + return [] + + grouped: dict[str, list[dict[str, Any]]] = {} + for event in events: + grouped.setdefault(str(event["thread_id"]), []).append(event) + + records_by_thread: dict[str, list[dict[str, Any]]] = {} + for thread_id, thread_events in grouped.items(): + async with await self._thread_lock(thread_id): + records_by_thread[thread_id] = self._append_thread_events(thread_id, thread_events) + + indexes = {thread_id: 0 for thread_id in records_by_thread} + ordered: list[dict[str, Any]] = [] + for event in events: + thread_id = str(event["thread_id"]) + index = indexes[thread_id] + ordered.append(records_by_thread[thread_id][index]) + indexes[thread_id] = index + 1 + return ordered + + async def list_messages( + self, + thread_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, + ) -> list[dict[str, Any]]: + events = [event for event in await self._read_thread_events(thread_id) if event.get("category") == "message"] + if before_seq is not None: + events = [event for event in events if int(event["seq"]) < before_seq] + return events[-limit:] + if after_seq is not None: + events = [event for event in events if int(event["seq"]) > after_seq] + return events[:limit] + return events[-limit:] + + async def list_events( + self, + thread_id: str, + run_id: str, + *, + event_types: list[str] | None = 
None, + limit: int = 500, + ) -> list[dict[str, Any]]: + event_type_set = set(event_types or []) + events = [ + event + for event in await self._read_thread_events(thread_id) + if event.get("run_id") == run_id and (not event_type_set or event.get("event_type") in event_type_set) + ] + return events[:limit] + + async def list_messages_by_run( + self, + thread_id: str, + run_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, + ) -> list[dict[str, Any]]: + events = [ + event + for event in await self._read_thread_events(thread_id) + if event.get("run_id") == run_id and event.get("category") == "message" + ] + if before_seq is not None: + events = [event for event in events if int(event["seq"]) < before_seq] + return events[-limit:] + if after_seq is not None: + events = [event for event in events if int(event["seq"]) > after_seq] + return events[:limit] + return events[-limit:] + + async def count_messages(self, thread_id: str) -> int: + return len(await self.list_messages(thread_id, limit=10**9)) + + async def delete_by_thread(self, thread_id: str) -> int: + async with await self._thread_lock(thread_id): + count = len(self._read_thread_events_sync(thread_id)) + shutil.rmtree(self._thread_dir(thread_id), ignore_errors=True) + return count + + async def delete_by_run(self, thread_id: str, run_id: str) -> int: + async with await self._thread_lock(thread_id): + events = self._read_thread_events_sync(thread_id) + kept = [event for event in events if event.get("run_id") != run_id] + deleted = len(events) - len(kept) + if deleted: + self._write_thread_events(thread_id, kept) + return deleted + + async def _thread_lock(self, thread_id: str) -> asyncio.Lock: + async with self._locks_guard: + lock = self._locks.get(thread_id) + if lock is None: + lock = asyncio.Lock() + self._locks[thread_id] = lock + return lock + + def _append_thread_events(self, thread_id: str, events: list[dict[str, Any]]) -> list[dict[str, Any]]: + thread_dir 
= self._thread_dir(thread_id) + thread_dir.mkdir(parents=True, exist_ok=True) + + seq = self._read_seq(thread_id) + records: list[dict[str, Any]] = [] + with self._events_path(thread_id).open("a", encoding="utf-8") as file: + for event in events: + seq += 1 + record = self._normalize_event(event, seq=seq) + file.write(json.dumps(record, ensure_ascii=False, default=str)) + file.write("\n") + records.append(record) + self._write_seq(thread_id, seq) + return records + + def _normalize_event(self, event: dict[str, Any], *, seq: int) -> dict[str, Any]: + created_at = event.get("created_at") + if isinstance(created_at, datetime): + created_at_value = created_at.isoformat() + elif created_at: + created_at_value = str(created_at) + else: + created_at_value = datetime.now(UTC).isoformat() + + return { + "thread_id": str(event["thread_id"]), + "run_id": str(event["run_id"]), + "seq": seq, + "event_type": str(event["event_type"]), + "category": str(event["category"]), + "content": event.get("content", ""), + "metadata": dict(event.get("metadata") or {}), + "created_at": created_at_value, + } + + async def _read_thread_events(self, thread_id: str) -> list[dict[str, Any]]: + async with await self._thread_lock(thread_id): + return self._read_thread_events_sync(thread_id) + + def _read_thread_events_sync(self, thread_id: str) -> list[dict[str, Any]]: + path = self._events_path(thread_id) + if not path.exists(): + return [] + + events: list[dict[str, Any]] = [] + with path.open(encoding="utf-8") as file: + for line in file: + stripped = line.strip() + if stripped: + events.append(json.loads(stripped)) + return events + + def _write_thread_events(self, thread_id: str, events: Iterable[dict[str, Any]]) -> None: + thread_dir = self._thread_dir(thread_id) + thread_dir.mkdir(parents=True, exist_ok=True) + temp_path = self._events_path(thread_id).with_suffix(".jsonl.tmp") + with temp_path.open("w", encoding="utf-8") as file: + for event in events: + file.write(json.dumps(event, 
ensure_ascii=False, default=str)) + file.write("\n") + temp_path.replace(self._events_path(thread_id)) + + def _read_seq(self, thread_id: str) -> int: + path = self._seq_path(thread_id) + if not path.exists(): + return 0 + try: + return int(path.read_text(encoding="utf-8").strip() or "0") + except ValueError: + return 0 + + def _write_seq(self, thread_id: str, seq: int) -> None: + self._seq_path(thread_id).write_text(str(seq), encoding="utf-8") + + def _thread_dir(self, thread_id: str) -> Path: + return self._base_dir / "threads" / thread_id + + def _events_path(self, thread_id: str) -> Path: + return self._thread_dir(thread_id) / "events.jsonl" + + def _seq_path(self, thread_id: str) -> Path: + return self._thread_dir(thread_id) / "seq" diff --git a/backend/app/infra/storage/__init__.py b/backend/app/infra/storage/__init__.py new file mode 100644 index 000000000..29ff550ea --- /dev/null +++ b/backend/app/infra/storage/__init__.py @@ -0,0 +1,14 @@ +"""Storage-facing adapters owned by the app layer.""" + +from .run_events import AppRunEventStore +from .runs import FeedbackStoreAdapter, RunStoreAdapter, StorageRunObserver +from .thread_meta import ThreadMetaStorage, ThreadMetaStoreAdapter + +__all__ = [ + "AppRunEventStore", + "FeedbackStoreAdapter", + "RunStoreAdapter", + "StorageRunObserver", + "ThreadMetaStorage", + "ThreadMetaStoreAdapter", +] diff --git a/backend/app/infra/storage/run_events.py b/backend/app/infra/storage/run_events.py new file mode 100644 index 000000000..a81eda7be --- /dev/null +++ b/backend/app/infra/storage/run_events.py @@ -0,0 +1,166 @@ +"""App-owned adapter from runs callbacks to storage run event repository.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from store.repositories import RunEvent, RunEventCreate, build_run_event_repository, build_thread_meta_repository + +from deerflow.runtime.actor_context import 
get_actor_context + + +class AppRunEventStore: + """Implements the harness RunEventStore protocol using storage repositories.""" + + def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: + self._session_factory = session_factory + + async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]: + if not events: + return [] + + denied = {str(event["thread_id"]) for event in events if not await self._thread_visible(str(event["thread_id"]))} + if denied: + raise PermissionError(f"actor is not allowed to append events for thread(s): {', '.join(sorted(denied))}") + + async with self._session_factory() as session: + try: + repo = build_run_event_repository(session) + rows = await repo.append_batch([_event_create_from_dict(event) for event in events]) + await session.commit() + except Exception: + await session.rollback() + raise + + return [_event_to_dict(row) for row in rows] + + async def list_messages( + self, + thread_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, + ) -> list[dict[str, Any]]: + if not await self._thread_visible(thread_id): + return [] + async with self._session_factory() as session: + repo = build_run_event_repository(session) + rows = await repo.list_messages( + thread_id, + limit=limit, + before_seq=before_seq, + after_seq=after_seq, + ) + return [_event_to_dict(row) for row in rows] + + async def list_events( + self, + thread_id: str, + run_id: str, + *, + event_types: list[str] | None = None, + limit: int = 500, + ) -> list[dict[str, Any]]: + if not await self._thread_visible(thread_id): + return [] + async with self._session_factory() as session: + repo = build_run_event_repository(session) + rows = await repo.list_events(thread_id, run_id, event_types=event_types, limit=limit) + return [_event_to_dict(row) for row in rows] + + async def list_messages_by_run( + self, + thread_id: str, + run_id: str, + *, + limit: int = 50, + before_seq: int | None = 
None, + after_seq: int | None = None, + ) -> list[dict[str, Any]]: + if not await self._thread_visible(thread_id): + return [] + async with self._session_factory() as session: + repo = build_run_event_repository(session) + rows = await repo.list_messages_by_run( + thread_id, + run_id, + limit=limit, + before_seq=before_seq, + after_seq=after_seq, + ) + return [_event_to_dict(row) for row in rows] + + async def count_messages(self, thread_id: str) -> int: + if not await self._thread_visible(thread_id): + return 0 + async with self._session_factory() as session: + repo = build_run_event_repository(session) + return await repo.count_messages(thread_id) + + async def delete_by_thread(self, thread_id: str) -> int: + if not await self._thread_visible(thread_id): + return 0 + async with self._session_factory() as session: + try: + repo = build_run_event_repository(session) + count = await repo.delete_by_thread(thread_id) + await session.commit() + except Exception: + await session.rollback() + raise + return count + + async def delete_by_run(self, thread_id: str, run_id: str) -> int: + if not await self._thread_visible(thread_id): + return 0 + async with self._session_factory() as session: + try: + repo = build_run_event_repository(session) + count = await repo.delete_by_run(thread_id, run_id) + await session.commit() + except Exception: + await session.rollback() + raise + return count + + async def _thread_visible(self, thread_id: str) -> bool: + actor = get_actor_context() + if actor is None or actor.user_id is None: + return True + + async with self._session_factory() as session: + thread_repo = build_thread_meta_repository(session) + thread = await thread_repo.get_thread_meta(thread_id) + + if thread is None: + return True + return thread.user_id is None or thread.user_id == actor.user_id + + +def _event_create_from_dict(event: dict[str, Any]) -> RunEventCreate: + created_at = event.get("created_at") + return RunEventCreate( + thread_id=str(event["thread_id"]), + 
run_id=str(event["run_id"]), + event_type=str(event["event_type"]), + category=str(event["category"]), + content=event.get("content", ""), + metadata=dict(event.get("metadata") or {}), + created_at=datetime.fromisoformat(created_at) if isinstance(created_at, str) else created_at, + ) + + +def _event_to_dict(event: RunEvent) -> dict[str, Any]: + return { + "thread_id": event.thread_id, + "run_id": event.run_id, + "event_type": event.event_type, + "category": event.category, + "content": event.content, + "metadata": event.metadata, + "seq": event.seq, + "created_at": event.created_at.isoformat(), + } diff --git a/backend/app/infra/storage/runs.py b/backend/app/infra/storage/runs.py new file mode 100644 index 000000000..efc6fd019 --- /dev/null +++ b/backend/app/infra/storage/runs.py @@ -0,0 +1,515 @@ +"""Run lifecycle persistence adapters owned by the app layer.""" + +from __future__ import annotations + +import logging +import uuid +from collections.abc import Callable +from typing import Protocol, TypedDict, Unpack, cast + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker +from store.repositories import FeedbackCreate, Run, RunCreate, build_feedback_repository, build_run_repository + +from deerflow.runtime.actor_context import AUTO, resolve_user_id +from deerflow.runtime.serialization import serialize_lc_object +from deerflow.runtime.runs.observer import LifecycleEventType, RunLifecycleEvent, RunObserver +from deerflow.runtime.stream_bridge import JSONValue + +from .thread_meta import ThreadMetaStorage + +logger = logging.getLogger(__name__) + + +class RunCreateFields(TypedDict, total=False): + status: str + created_at: str + started_at: str + ended_at: str + assistant_id: str | None + user_id: str | None + follow_up_to_run_id: str | None + metadata: dict[str, JSONValue] + kwargs: dict[str, JSONValue] + + +class RunStatusUpdateFields(TypedDict, total=False): + started_at: str + ended_at: str + metadata: dict[str, JSONValue] + + +class 
RunCompletionFields(TypedDict, total=False): + total_input_tokens: int + total_output_tokens: int + total_tokens: int + llm_call_count: int + lead_agent_tokens: int + subagent_tokens: int + middleware_tokens: int + message_count: int + last_ai_message: str | None + first_human_message: str | None + error: str | None + + +class RunRow(TypedDict, total=False): + run_id: str + thread_id: str + assistant_id: str | None + status: str + multitask_strategy: str + follow_up_to_run_id: str | None + metadata: dict[str, JSONValue] + created_at: str + updated_at: str + started_at: str | None + ended_at: str | None + error: str | None + + +class RunReadRepository(Protocol): + """Protocol for durable run queries.""" + + async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None: ... + + async def list_by_thread( + self, + thread_id: str, + *, + limit: int = 100, + user_id: str | None | object = AUTO, + ) -> list[RunRow]: ... + + +class RunWriteRepository(Protocol): + """Protocol for durable run writes.""" + + async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None: ... + async def update_status( + self, + run_id: str, + status: str, + **kwargs: Unpack[RunStatusUpdateFields], + ) -> None: ... + async def set_error(self, run_id: str, error: str) -> None: ... + async def update_run_completion( + self, + run_id: str, + *, + status: str, + **kwargs: Unpack[RunCompletionFields], + ) -> None: ... + + +class RunDeleteRepository(Protocol): + """Protocol for durable run deletion.""" + + async def delete(self, run_id: str) -> bool: ... 
+ + +class _RepositoryContext: + def __init__( + self, + session_factory: async_sessionmaker[AsyncSession], + build_repo: Callable[[AsyncSession], object], + *, + commit: bool, + ) -> None: + self._session_factory = session_factory + self._build_repo = build_repo + self._commit = commit + self._session: AsyncSession | None = None + + async def __aenter__(self): + self._session = self._session_factory() + return self._build_repo(self._session) + + async def __aexit__(self, exc_type, exc, tb) -> None: + if self._session is None: + return + try: + if self._commit: + if exc_type is None: + await self._session.commit() + else: + await self._session.rollback() + finally: + await self._session.close() + + +def _run_to_row(row: Run) -> RunRow: + return { + "run_id": row.run_id, + "thread_id": row.thread_id, + "assistant_id": row.assistant_id, + "user_id": row.user_id, + "status": row.status, + "model_name": row.model_name, + "multitask_strategy": row.multitask_strategy, + "follow_up_to_run_id": row.follow_up_to_run_id, + "metadata": cast(dict[str, JSONValue], row.metadata), + "kwargs": cast(dict[str, JSONValue], row.kwargs), + "created_at": row.created_time.isoformat(), + "updated_at": row.updated_time.isoformat() if row.updated_time else "", + "total_input_tokens": row.total_input_tokens, + "total_output_tokens": row.total_output_tokens, + "total_tokens": row.total_tokens, + "llm_call_count": row.llm_call_count, + "lead_agent_tokens": row.lead_agent_tokens, + "subagent_tokens": row.subagent_tokens, + "middleware_tokens": row.middleware_tokens, + "message_count": row.message_count, + "first_human_message": row.first_human_message, + "last_ai_message": row.last_ai_message, + "error": row.error, + } + + +class FeedbackStoreAdapter: + """Expose feedback route semantics on top of storage package repositories.""" + + def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: + self._session_factory = session_factory + + async def create( + self, + *, + 
class FeedbackStoreAdapter:
    """Expose feedback route semantics on top of storage package repositories.

    Each public method opens its own session via _RepositoryContext:
    _read() never commits, _transaction() commits on success and rolls back
    on exception. Ratings are restricted to +1 / -1 throughout.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._session_factory = session_factory

    async def create(
        self,
        *,
        run_id: str,
        thread_id: str,
        rating: int,
        owner_id: str | None = None,
        user_id: str | None = None,
        message_id: str | None = None,
        comment: str | None = None,
    ) -> dict[str, object]:
        """Create a feedback row with a fresh UUID.

        ``owner_id`` is a legacy alias for ``user_id``; ``user_id`` wins when
        both are supplied.

        Raises:
            ValueError: if *rating* is not +1 or -1.
        """
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")
        effective_user_id = user_id if user_id is not None else owner_id
        async with self._transaction() as repo:
            row = await repo.create_feedback(
                FeedbackCreate(
                    feedback_id=str(uuid.uuid4()),
                    run_id=run_id,
                    thread_id=thread_id,
                    rating=rating,
                    user_id=effective_user_id,
                    message_id=message_id,
                    comment=comment,
                )
            )
            return _feedback_to_dict(row)

    async def get(self, feedback_id: str) -> dict[str, object] | None:
        """Fetch one feedback row by id, or None."""
        async with self._read() as repo:
            row = await repo.get_feedback(feedback_id)
            return _feedback_to_dict(row) if row is not None else None

    async def list_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 100,
        user_id: str | None = None,
    ) -> list[dict[str, object]]:
        """List feedback for a run, filtered in-memory by thread and user.

        NOTE: the repository is queried by run_id only; thread/user filtering
        and the limit are applied client-side after the fetch.
        """
        async with self._read() as repo:
            rows = await repo.list_feedback_by_run(run_id)
            filtered = [row for row in rows if row.thread_id == thread_id]
            if user_id is not None:
                filtered = [row for row in filtered if row.user_id == user_id]
            return [_feedback_to_dict(row) for row in filtered][:limit]

    async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict[str, object]]:
        """List feedback for a whole thread (limit applied client-side)."""
        async with self._read() as repo:
            rows = await repo.list_feedback_by_thread(thread_id)
            return [_feedback_to_dict(row) for row in rows][:limit]

    async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict[str, object]:
        """Return {run_id, total, positive, negative} counts for a run."""
        rows = await self.list_by_run(thread_id, run_id)
        positive = sum(1 for row in rows if row["rating"] == 1)
        negative = sum(1 for row in rows if row["rating"] == -1)
        return {"run_id": run_id, "total": len(rows), "positive": positive, "negative": negative}

    async def delete(self, feedback_id: str) -> bool:
        """Delete one feedback row; True if it existed."""
        async with self._transaction() as repo:
            return await repo.delete_feedback(feedback_id)

    async def upsert(
        self,
        *,
        run_id: str,
        thread_id: str,
        rating: int,
        user_id: str,
        comment: str | None = None,
    ) -> dict[str, object]:
        """Replace a user's feedback for a run (delete + re-create).

        The existing feedback_id is reused on replace so callers keep a
        stable identifier. Delete and create happen inside one transaction.

        Raises:
            ValueError: if *rating* is not +1 or -1.
        """
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")
        async with self._transaction() as repo:
            rows = await repo.list_feedback_by_run(run_id)
            existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
            feedback_id = existing.feedback_id if existing is not None else str(uuid.uuid4())
            if existing is not None:
                await repo.delete_feedback(existing.feedback_id)
            row = await repo.create_feedback(
                FeedbackCreate(
                    feedback_id=feedback_id,
                    run_id=run_id,
                    thread_id=thread_id,
                    rating=rating,
                    user_id=user_id,
                    comment=comment,
                )
            )
            return _feedback_to_dict(row)

    async def delete_by_run(self, *, thread_id: str, run_id: str, user_id: str) -> bool:
        """Delete one user's feedback for a run; False if none existed."""
        async with self._transaction() as repo:
            rows = await repo.list_feedback_by_run(run_id)
            existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
            if existing is None:
                return False
            return await repo.delete_feedback(existing.feedback_id)

    async def list_by_thread_grouped(self, thread_id: str, *, user_id: str) -> dict[str, dict[str, object]]:
        """Map run_id -> the user's feedback row for that thread.

        NOTE(review): if a user somehow has multiple rows for one run, the
        last row returned by list_by_thread wins — confirm that is acceptable.
        """
        rows = await self.list_by_thread(thread_id)
        return {
            row["run_id"]: row
            for row in rows
            if row["user_id"] == user_id
        }

    def _read(self) -> _RepositoryContext:
        # Read-only context: session is closed without committing.
        return _RepositoryContext(self._session_factory, build_feedback_repository, commit=False)

    def _transaction(self) -> _RepositoryContext:
        # Write context: commit on clean exit, rollback on exception.
        return _RepositoryContext(self._session_factory, build_feedback_repository, commit=True)


def _feedback_to_dict(row) -> dict[str, object]:
    """Serialize a feedback row; 'owner_id' duplicates user_id for legacy callers."""
    return {
        "feedback_id": row.feedback_id,
        "run_id": row.run_id,
        "thread_id": row.thread_id,
        "user_id": row.user_id,
        "owner_id": row.user_id,
        "message_id": row.message_id,
        "rating": row.rating,
        "comment": row.comment,
        "created_at": row.created_time.isoformat(),
    }
class RunStoreAdapter:
    """Expose runs facade storage semantics on top of storage package repositories.

    User scoping: read and delete methods resolve an effective user id via
    resolve_user_id (AUTO pulls it from the actor context) and hide rows
    owned by other users.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._session_factory = session_factory

    async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None:
        """Fetch one run; None if missing or owned by a different user."""
        effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.get")
        async with self._read() as repo:
            row = await repo.get_run(run_id)
            if row is None:
                return None
            # Ownership check: a non-None effective user only sees their own runs.
            if effective_user_id is not None and row.user_id != effective_user_id:
                return None
            return _run_to_row(row)

    async def list_by_thread(
        self,
        thread_id: str,
        *,
        limit: int = 100,
        user_id: str | None | object = AUTO,
    ) -> list[RunRow]:
        """List a thread's runs; ownership filtering happens client-side."""
        effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.list_by_thread")
        async with self._read() as repo:
            rows = await repo.list_runs_by_thread(thread_id, limit=limit, offset=0)
            if effective_user_id is not None:
                rows = [row for row in rows if row.user_id == effective_user_id]
            return [_run_to_row(row) for row in rows]

    async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None:
        """Create a run row; metadata/kwargs are serialized to plain JSON first."""
        metadata = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("metadata") or {}))
        run_kwargs = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("kwargs") or {}))
        effective_user_id = resolve_user_id(kwargs.get("user_id", AUTO), method_name="RunStoreAdapter.create")
        async with self._transaction() as repo:
            await repo.create_run(
                RunCreate(
                    run_id=run_id,
                    thread_id=thread_id,
                    assistant_id=kwargs.get("assistant_id"),
                    user_id=effective_user_id,
                    status=kwargs.get("status", "pending"),
                    metadata=dict(metadata),
                    kwargs=dict(run_kwargs),
                    follow_up_to_run_id=kwargs.get("follow_up_to_run_id"),
                )
            )

    async def delete(self, run_id: str, *, user_id: str | None | object = AUTO) -> bool:
        """Delete a run the effective user owns; False if missing or foreign."""
        async with self._transaction() as repo:
            existing = await repo.get_run(run_id)
            if existing is None:
                return False
            effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.delete")
            if effective_user_id is not None and existing.user_id != effective_user_id:
                return False
            await repo.delete_run(run_id)
            return True

    async def update_status(
        self,
        run_id: str,
        status: str,
        **kwargs: Unpack[RunStatusUpdateFields],
    ) -> None:
        """Set a run's status.

        NOTE(review): the started_at/ended_at/metadata kwargs are accepted for
        protocol compatibility but not forwarded to the repository here —
        confirm whether update_run_status supports them or they are lost.
        """
        async with self._transaction() as repo:
            await repo.update_run_status(run_id, status)

    async def set_error(self, run_id: str, error: str) -> None:
        """Mark a run as errored and record the error text."""
        async with self._transaction() as repo:
            await repo.update_run_status(run_id, "error", error=error)

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        **kwargs: Unpack[RunCompletionFields],
    ) -> None:
        """Record final status plus token/message accounting (defaults to 0/None)."""
        async with self._transaction() as repo:
            await repo.update_run_completion(
                run_id,
                status=status,
                total_input_tokens=kwargs.get("total_input_tokens", 0),
                total_output_tokens=kwargs.get("total_output_tokens", 0),
                total_tokens=kwargs.get("total_tokens", 0),
                llm_call_count=kwargs.get("llm_call_count", 0),
                lead_agent_tokens=kwargs.get("lead_agent_tokens", 0),
                subagent_tokens=kwargs.get("subagent_tokens", 0),
                middleware_tokens=kwargs.get("middleware_tokens", 0),
                message_count=kwargs.get("message_count", 0),
                last_ai_message=kwargs.get("last_ai_message"),
                first_human_message=kwargs.get("first_human_message"),
                error=kwargs.get("error"),
            )

    def _read(self) -> _RepositoryContext:
        # Read-only context: session closed without commit.
        return _RepositoryContext(self._session_factory, build_run_repository, commit=False)

    def _transaction(self) -> _RepositoryContext:
        # Write context: commit on clean exit, rollback on exception.
        return _RepositoryContext(self._session_factory, build_run_repository, commit=True)
class StorageRunObserver(RunObserver):
    """Persist run lifecycle state into app-owned repositories.

    Both collaborators are optional; handlers silently skip the work when the
    corresponding dependency is absent. on_event never raises: persistence
    failures are logged so they cannot break the run itself.
    """

    def __init__(
        self,
        run_write_repo: RunWriteRepository | None = None,
        thread_meta_storage: ThreadMetaStorage | None = None,
    ) -> None:
        self._run_write_repo = run_write_repo
        self._thread_meta_storage = thread_meta_storage

    async def on_event(self, event: RunLifecycleEvent) -> None:
        """Dispatch *event*; log (not raise) on persistence failure."""
        try:
            await self._dispatch(event)
        except Exception:
            logger.exception(
                "StorageRunObserver failed to persist event %s for run %s",
                event.event_type,
                event.run_id,
            )

    async def _dispatch(self, event: RunLifecycleEvent) -> None:
        # Table-driven dispatch; unknown event types are ignored.
        handlers = {
            LifecycleEventType.RUN_STARTED: self._handle_run_started,
            LifecycleEventType.RUN_COMPLETED: self._handle_run_completed,
            LifecycleEventType.RUN_FAILED: self._handle_run_failed,
            LifecycleEventType.RUN_CANCELLED: self._handle_run_cancelled,
            LifecycleEventType.THREAD_STATUS_UPDATED: self._handle_thread_status,
        }

        handler = handlers.get(event.event_type)
        if handler:
            await handler(event)

    async def _handle_run_started(self, event: RunLifecycleEvent) -> None:
        # Mark the run running, stamped with the event's occurrence time.
        if self._run_write_repo:
            await self._run_write_repo.update_status(
                run_id=event.run_id,
                status="running",
                started_at=event.occurred_at.isoformat(),
            )

    async def _handle_run_completed(self, event: RunLifecycleEvent) -> None:
        payload = dict(event.payload) if event.payload else {}
        if self._run_write_repo:
            # Prefer the richer completion record when the payload carries one.
            completion_data = payload.get("completion_data")
            if isinstance(completion_data, dict):
                await self._run_write_repo.update_run_completion(
                    run_id=event.run_id,
                    status="success",
                    **cast(RunCompletionFields, completion_data),
                )
            else:
                await self._run_write_repo.update_status(
                    run_id=event.run_id,
                    status="success",
                    ended_at=event.occurred_at.isoformat(),
                )

        # A completed run may also carry a generated thread title.
        if self._thread_meta_storage and "title" in payload:
            await self._thread_meta_storage.sync_thread_title(
                thread_id=event.thread_id,
                title=payload["title"],
            )

    async def _handle_run_failed(self, event: RunLifecycleEvent) -> None:
        if self._run_write_repo:
            payload = dict(event.payload) if event.payload else {}
            error = payload.get("error", "Unknown error")
            completion_data = payload.get("completion_data")
            if isinstance(completion_data, dict):
                await self._run_write_repo.update_run_completion(
                    run_id=event.run_id,
                    status="error",
                    error=str(error),
                    **cast(RunCompletionFields, completion_data),
                )
            else:
                await self._run_write_repo.update_status(
                    run_id=event.run_id,
                    status="error",
                    ended_at=event.occurred_at.isoformat(),
                )
                # Without completion data the error text is persisted separately.
                await self._run_write_repo.set_error(run_id=event.run_id, error=str(error))

    async def _handle_run_cancelled(self, event: RunLifecycleEvent) -> None:
        # Cancellation is recorded as status "interrupted".
        if self._run_write_repo:
            payload = dict(event.payload) if event.payload else {}
            completion_data = payload.get("completion_data")
            if isinstance(completion_data, dict):
                await self._run_write_repo.update_run_completion(
                    run_id=event.run_id,
                    status="interrupted",
                    **cast(RunCompletionFields, completion_data),
                )
            else:
                await self._run_write_repo.update_status(
                    run_id=event.run_id,
                    status="interrupted",
                    ended_at=event.occurred_at.isoformat(),
                )

    async def _handle_thread_status(self, event: RunLifecycleEvent) -> None:
        # Mirror the thread's status into thread metadata; defaults to "idle".
        if self._thread_meta_storage:
            payload = dict(event.payload) if event.payload else {}
            status = payload.get("status", "idle")
            await self._thread_meta_storage.sync_thread_status(
                thread_id=event.thread_id,
                status=status,
            )
"""Thread metadata storage adapter owned by the app layer."""

from __future__ import annotations

from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from store.repositories import build_thread_meta_repository
from store.repositories.contracts import (
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
)
from deerflow.runtime.actor_context import AUTO, resolve_user_id


class ThreadMetaStoreAdapter:
    """Use storage package thread repositories with per-call sessions.

    Thin pass-through: every method opens a fresh session, writes commit on
    success and roll back on exception (via _ThreadMetaRepositoryContext).
    No user scoping is applied at this layer — see ThreadMetaStorage for that.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        self._session_factory = session_factory

    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        async with self._transaction() as repo:
            return await repo.create_thread_meta(data)

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        async with self._read() as repo:
            return await repo.get_thread_meta(thread_id)

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        assistant_id: str | None = None,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        # None fields are passed through; the repository decides whether
        # None means "leave unchanged".
        async with self._transaction() as repo:
            await repo.update_thread_meta(
                thread_id,
                assistant_id=assistant_id,
                display_name=display_name,
                status=status,
                metadata=metadata,
            )

    async def delete_thread(self, thread_id: str) -> None:
        async with self._transaction() as repo:
            await repo.delete_thread(thread_id)

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        async with self._read() as repo:
            return await repo.search_threads(
                metadata=metadata,
                status=status,
                user_id=user_id,
                assistant_id=assistant_id,
                limit=limit,
                offset=offset,
            )

    def _read(self):
        # Read-only context: session closed without commit.
        return _ThreadMetaRepositoryContext(self._session_factory, commit=False)

    def _transaction(self):
        # Write context: commit on clean exit, rollback on exception.
        return _ThreadMetaRepositoryContext(self._session_factory, commit=True)


class _ThreadMetaRepositoryContext:
    """Async context manager yielding a thread-meta repo over a fresh session.

    Commits on clean exit when *commit* is True, rolls back on exception,
    and always closes the session.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, commit: bool) -> None:
        self._session_factory = session_factory
        self._commit = commit
        self._session: AsyncSession | None = None

    async def __aenter__(self):
        self._session = self._session_factory()
        return build_thread_meta_repository(self._session)

    async def __aexit__(self, exc_type, exc, tb) -> None:
        if self._session is None:
            return
        try:
            if self._commit:
                if exc_type is None:
                    await self._session.commit()
                else:
                    await self._session.rollback()
        finally:
            await self._session.close()


class ThreadMetaStorage:
    """App-facing adapter around the storage thread metadata contract.

    Adds actor-aware user scoping on top of a ThreadMetaRepositoryProtocol:
    AUTO user ids are resolved from the actor context via resolve_user_id.
    """

    def __init__(self, repo: ThreadMetaRepositoryProtocol) -> None:
        self._repo = repo

    async def get_thread(self, thread_id: str, *, user_id: str | None | object = AUTO) -> ThreadMeta | None:
        """Fetch a thread; None if missing or owned by a different user."""
        thread = await self._repo.get_thread_meta(thread_id)
        if thread is None:
            return None
        effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.get_thread")
        if effective_user_id is not None and thread.user_id != effective_user_id:
            return None
        return thread

    async def ensure_thread(
        self,
        *,
        thread_id: str,
        assistant_id: str | None = None,
        metadata: dict[str, Any] | None = None,
        user_id: str | None | object = AUTO,
    ) -> ThreadMeta:
        """Return the visible thread, creating it (owned by the effective user) if absent."""
        effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.ensure_thread")
        existing = await self.get_thread(thread_id, user_id=effective_user_id)
        if existing is not None:
            return existing

        return await self._repo.create_thread_meta(
            ThreadMetaCreate(
                thread_id=thread_id,
                assistant_id=assistant_id,
                user_id=effective_user_id,
                metadata=metadata or {},
            )
        )

    async def ensure_thread_running(
        self,
        *,
        thread_id: str,
        assistant_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> ThreadMeta | None:
        """Create the thread in status "running", or flip an existing one to it.

        NOTE(review): unlike ensure_thread, creation here sets no user_id, so
        threads created on this path are ownerless — confirm this is intended.
        """
        existing = await self._repo.get_thread_meta(thread_id)
        if existing is None:
            return await self._repo.create_thread_meta(
                ThreadMetaCreate(
                    thread_id=thread_id,
                    assistant_id=assistant_id,
                    status="running",
                    metadata=metadata or {},
                )
            )

        await self._repo.update_thread_meta(thread_id, status="running")
        # Re-read so the caller sees the post-update row.
        return await self._repo.get_thread_meta(thread_id)

    async def sync_thread_title(self, *, thread_id: str, title: str) -> None:
        """Persist a (possibly generated) display name for the thread."""
        await self._repo.update_thread_meta(thread_id, display_name=title)

    async def sync_thread_assistant_id(self, *, thread_id: str, assistant_id: str) -> None:
        await self._repo.update_thread_meta(thread_id, assistant_id=assistant_id)

    async def sync_thread_status(self, *, thread_id: str, status: str) -> None:
        await self._repo.update_thread_meta(thread_id, status=status)

    async def sync_thread_metadata(
        self,
        *,
        thread_id: str,
        metadata: dict[str, Any],
    ) -> None:
        await self._repo.update_thread_meta(thread_id, metadata=metadata)

    async def delete_thread(self, thread_id: str) -> None:
        await self._repo.delete_thread(thread_id)

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None | object = AUTO,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        """Search threads; blank/whitespace-only filters collapse to None."""
        normalized_status = status.strip() if status is not None else None
        resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.search_threads")
        normalized_user_id = resolved_user_id.strip() if resolved_user_id is not None else None
        normalized_assistant_id = (
            assistant_id.strip() if assistant_id is not None else None
        )

        # `or None` turns empty strings (post-strip) into "no filter".
        return await self._repo.search_threads(
            metadata=metadata,
            status=normalized_status or None,
            user_id=normalized_user_id or None,
            assistant_id=normalized_assistant_id or None,
            limit=limit,
            offset=offset,
        )


__all__ = ["ThreadMetaStorage", "ThreadMetaStoreAdapter"]
b/backend/app/infra/stream_bridge/__init__.py @@ -0,0 +1,6 @@ +"""App-owned stream bridge adapters and factory.""" + +from .factory import build_stream_bridge +from .adapters import MemoryStreamBridge, RedisStreamBridge + +__all__ = ["MemoryStreamBridge", "RedisStreamBridge", "build_stream_bridge"] diff --git a/backend/app/infra/stream_bridge/adapters/__init__.py b/backend/app/infra/stream_bridge/adapters/__init__.py new file mode 100644 index 000000000..76727cfd8 --- /dev/null +++ b/backend/app/infra/stream_bridge/adapters/__init__.py @@ -0,0 +1,6 @@ +"""Concrete stream bridge adapters owned by the app layer.""" + +from .memory import MemoryStreamBridge +from .redis import RedisStreamBridge + +__all__ = ["MemoryStreamBridge", "RedisStreamBridge"] diff --git a/backend/app/infra/stream_bridge/adapters/memory.py b/backend/app/infra/stream_bridge/adapters/memory.py new file mode 100644 index 000000000..bde6cf614 --- /dev/null +++ b/backend/app/infra/stream_bridge/adapters/memory.py @@ -0,0 +1,450 @@ +"""In-memory stream bridge implementation owned by the app layer.""" + +from __future__ import annotations + +import asyncio +import json +import logging +import time +from collections.abc import AsyncIterator +from dataclasses import dataclass, field +from typing import Any, Literal + +from deerflow.runtime.stream_bridge import ( + CANCELLED_SENTINEL, + END_SENTINEL, + HEARTBEAT_SENTINEL, + TERMINAL_STATES, + ResumeResult, + StreamBridge, + StreamEvent, + StreamStatus, +) +from deerflow.runtime.stream_bridge.exceptions import ( + BridgeClosedError, + StreamCapacityExceededError, + StreamTerminatedError, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class _RunStream: + condition: asyncio.Condition = field(default_factory=asyncio.Condition) + events: list[StreamEvent] = field(default_factory=list) + id_to_offset: dict[str, int] = field(default_factory=dict) + start_offset: int = 0 + current_bytes: int = 0 + seq: int = 0 + status: StreamStatus = 
StreamStatus.ACTIVE + created_at: float = field(default_factory=time.monotonic) + last_publish_at: float | None = None + ended_at: float | None = None + subscriber_count: int = 0 + last_subscribe_at: float | None = None + awaiting_input: bool = False + awaiting_since: float | None = None + + +class MemoryStreamBridge(StreamBridge): + """Per-run in-memory event log implementation.""" + + def __init__( + self, + *, + max_events_per_stream: int = 256, + max_bytes_per_stream: int = 10 * 1024 * 1024, + max_active_streams: int = 1000, + stream_eviction_policy: Literal["reject", "lru"] = "lru", + terminal_retention_ttl: float = 300.0, + active_no_publish_timeout: float = 600.0, + orphan_timeout: float = 60.0, + max_stream_age: float = 86400.0, + hitl_extended_timeout: float = 7200.0, + cleanup_interval: float = 30.0, + queue_maxsize: int | None = None, + ) -> None: + if queue_maxsize is not None: + max_events_per_stream = queue_maxsize + + self._max_events = max_events_per_stream + self._max_bytes = max_bytes_per_stream + self._max_streams = max_active_streams + self._eviction_policy = stream_eviction_policy + self._terminal_ttl = terminal_retention_ttl + self._active_timeout = active_no_publish_timeout + self._orphan_timeout = orphan_timeout + self._max_age = max_stream_age + self._hitl_timeout = hitl_extended_timeout + self._cleanup_interval = cleanup_interval + self._streams: dict[str, _RunStream] = {} + self._registry_lock = asyncio.Lock() + self._closed = False + self._cleanup_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + if self._cleanup_task is None: + self._cleanup_task = asyncio.create_task(self._cleanup_loop()) + logger.info( + "MemoryStreamBridge started (max_events=%d, max_bytes=%d, max_streams=%d)", + self._max_events, + self._max_bytes, + self._max_streams, + ) + + async def close(self) -> None: + async with self._registry_lock: + self._closed = True + if self._cleanup_task is not None: + self._cleanup_task.cancel() + try: + 
await self._cleanup_task + except asyncio.CancelledError: + pass + self._cleanup_task = None + + for stream in self._streams.values(): + async with stream.condition: + stream.status = StreamStatus.CLOSED + stream.condition.notify_all() + + self._streams.clear() + logger.info("MemoryStreamBridge closed") + + async def _get_or_create_stream(self, run_id: str) -> _RunStream: + stream = self._streams.get(run_id) + if stream is not None: + return stream + + async with self._registry_lock: + if self._closed: + raise BridgeClosedError("Stream bridge is closed") + + stream = self._streams.get(run_id) + if stream is not None: + return stream + + if len(self._streams) >= self._max_streams: + if self._eviction_policy == "reject": + raise StreamCapacityExceededError( + f"Max {self._max_streams} active streams reached" + ) + evicted = self._evict_oldest_terminal() + if evicted is None: + raise StreamCapacityExceededError("All streams active, cannot evict") + logger.info("Evicted stream %s to make room", evicted) + + stream = _RunStream() + self._streams[run_id] = stream + logger.debug("Created stream for run %s", run_id) + return stream + + def _evict_oldest_terminal(self) -> str | None: + oldest_run_id: str | None = None + oldest_ended_at: float = float("inf") + for run_id, stream in self._streams.items(): + if stream.status in TERMINAL_STATES and stream.ended_at is not None: + if stream.ended_at < oldest_ended_at: + oldest_ended_at = stream.ended_at + oldest_run_id = run_id + if oldest_run_id is not None: + del self._streams[oldest_run_id] + return oldest_run_id + return None + + def _next_id(self, stream: _RunStream) -> str: + stream.seq += 1 + return f"{int(time.time() * 1000)}-{stream.seq}" + + def _estimate_size(self, event: StreamEvent) -> int: + base = len(event.id) + len(event.event) + 100 + if event.data is None: + return base + if isinstance(event.data, str): + return base + len(event.data) + if isinstance(event.data, (dict, list)): + try: + return base + 
len(json.dumps(event.data, default=str)) + except (TypeError, ValueError): + return base + 200 + return base + 50 + + def _evict_overflow(self, stream: _RunStream) -> None: + while len(stream.events) > self._max_events or stream.current_bytes > self._max_bytes: + if not stream.events: + break + evicted = stream.events.pop(0) + stream.id_to_offset.pop(evicted.id, None) + stream.current_bytes -= self._estimate_size(evicted) + stream.start_offset += 1 + + async def publish(self, run_id: str, event: str, data: Any) -> str: + stream = await self._get_or_create_stream(run_id) + async with stream.condition: + if stream.status != StreamStatus.ACTIVE: + raise StreamTerminatedError( + f"Cannot publish to {stream.status.value} stream" + ) + + entry = StreamEvent(id=self._next_id(stream), event=event, data=data) + absolute_offset = stream.start_offset + len(stream.events) + stream.events.append(entry) + stream.id_to_offset[entry.id] = absolute_offset + stream.current_bytes += self._estimate_size(entry) + stream.last_publish_at = time.monotonic() + self._evict_overflow(stream) + stream.condition.notify_all() + return entry.id + + async def publish_end(self, run_id: str) -> str: + return await self.publish_terminal(run_id, StreamStatus.ENDED) + + async def publish_terminal( + self, + run_id: str, + kind: StreamStatus, + data: Any = None, + ) -> str: + if kind not in TERMINAL_STATES: + raise ValueError(f"Invalid terminal kind: {kind}") + + stream = await self._get_or_create_stream(run_id) + async with stream.condition: + if stream.status != StreamStatus.ACTIVE: + for evt in reversed(stream.events): + if evt.event in ("end", "cancel", "error", "dead_letter"): + return evt.id + return "" + + event_name = { + StreamStatus.ENDED: "end", + StreamStatus.CANCELLED: "cancel", + StreamStatus.ERRORED: "error", + }[kind] + entry = StreamEvent(id=self._next_id(stream), event=event_name, data=data) + absolute_offset = stream.start_offset + len(stream.events) + stream.events.append(entry) + 
stream.id_to_offset[entry.id] = absolute_offset + stream.current_bytes += self._estimate_size(entry) + stream.status = kind + stream.ended_at = time.monotonic() + stream.awaiting_input = False + stream.condition.notify_all() + logger.debug("Stream %s terminal: %s", run_id, kind.value) + return entry.id + + async def cancel(self, run_id: str) -> None: + await self.publish_terminal(run_id, StreamStatus.CANCELLED) + + async def subscribe( + self, + run_id: str, + *, + last_event_id: str | None = None, + heartbeat_interval: float = 15.0, + ) -> AsyncIterator[StreamEvent]: + stream = await self._get_or_create_stream(run_id) + resume = self._resolve_resume_point(stream, last_event_id) + next_offset = resume.next_offset + + async with stream.condition: + stream.subscriber_count += 1 + stream.last_subscribe_at = time.monotonic() + + try: + while True: + entry_to_yield: StreamEvent | None = None + sentinel_to_yield: StreamEvent | None = None + should_return = False + should_wait = False + + async with stream.condition: + if self._closed or stream.status == StreamStatus.CLOSED: + sentinel_to_yield = CANCELLED_SENTINEL + should_return = True + elif next_offset < stream.start_offset: + next_offset = stream.start_offset + else: + local_index = next_offset - stream.start_offset + if 0 <= local_index < len(stream.events): + entry_to_yield = stream.events[local_index] + next_offset += 1 + if entry_to_yield.event in ("end", "cancel", "error", "dead_letter"): + should_return = True + elif stream.status in TERMINAL_STATES: + sentinel_to_yield = END_SENTINEL + should_return = True + else: + should_wait = True + try: + await asyncio.wait_for( + stream.condition.wait(), + timeout=heartbeat_interval, + ) + except TimeoutError: + pass + + if sentinel_to_yield is not None: + yield sentinel_to_yield + if should_return: + return + continue + + if entry_to_yield is not None: + yield entry_to_yield + if should_return: + return + continue + + if should_wait: + async with stream.condition: + 
local_index = next_offset - stream.start_offset + has_events = 0 <= local_index < len(stream.events) + is_terminal = stream.status in TERMINAL_STATES + if not has_events and not is_terminal: + yield HEARTBEAT_SENTINEL + + finally: + async with stream.condition: + stream.subscriber_count = max(0, stream.subscriber_count - 1) + + async def mark_awaiting_input(self, run_id: str) -> None: + stream = self._streams.get(run_id) + if stream is None: + return + async with stream.condition: + if stream.status == StreamStatus.ACTIVE: + stream.awaiting_input = True + stream.awaiting_since = time.monotonic() + logger.debug("Stream %s marked as awaiting input", run_id) + + async def cleanup(self, run_id: str, *, delay: float = 0) -> None: + if delay > 0: + await asyncio.sleep(delay) + await self._do_cleanup(run_id, "manual") + + async def _do_cleanup(self, run_id: str, reason: str) -> None: + async with self._registry_lock: + stream = self._streams.pop(run_id, None) + if stream is not None: + async with stream.condition: + stream.status = StreamStatus.CLOSED + stream.condition.notify_all() + logger.debug("Cleaned up stream %s (reason: %s)", run_id, reason) + + async def _mark_dead_letter(self, run_id: str, reason: str) -> None: + stream = self._streams.get(run_id) + if stream is None: + return + async with stream.condition: + if stream.status != StreamStatus.ACTIVE: + return + entry = StreamEvent( + id=self._next_id(stream), + event="dead_letter", + data={"reason": reason, "timestamp": time.time()}, + ) + absolute_offset = stream.start_offset + len(stream.events) + stream.events.append(entry) + stream.id_to_offset[entry.id] = absolute_offset + stream.current_bytes += self._estimate_size(entry) + stream.status = StreamStatus.ERRORED + stream.ended_at = time.monotonic() + stream.condition.notify_all() + logger.warning("Stream %s marked as dead letter: %s", run_id, reason) + + async def _cleanup_loop(self) -> None: + while not self._closed: + try: + await 
asyncio.sleep(self._cleanup_interval) + except asyncio.CancelledError: + break + + now = time.monotonic() + to_cleanup: list[tuple[str, str]] = [] + to_mark_dead: list[tuple[str, str]] = [] + + async with self._registry_lock: + for run_id, stream in list(self._streams.items()): + if now - stream.created_at > self._max_age: + to_cleanup.append((run_id, "max_age_exceeded")) + continue + + if stream.status == StreamStatus.ACTIVE: + timeout = self._hitl_timeout if stream.awaiting_input else self._active_timeout + last_activity = stream.last_publish_at or stream.created_at + if now - last_activity > timeout: + to_mark_dead.append((run_id, "no_publish_timeout")) + continue + + if stream.status in TERMINAL_STATES and stream.ended_at: + if stream.subscriber_count > 0: + continue + last_sub = stream.last_subscribe_at or stream.ended_at + if now - last_sub > self._orphan_timeout: + to_cleanup.append((run_id, "orphan")) + continue + if now - stream.ended_at > self._terminal_ttl: + to_cleanup.append((run_id, "ttl_expired")) + + for run_id, reason in to_mark_dead: + await self._mark_dead_letter(run_id, reason) + for run_id, reason in to_cleanup: + await self._do_cleanup(run_id, reason) + + def get_stats(self) -> dict[str, Any]: + active = sum(1 for s in self._streams.values() if s.status == StreamStatus.ACTIVE) + terminal = sum(1 for s in self._streams.values() if s.status in TERMINAL_STATES) + total_events = sum(len(s.events) for s in self._streams.values()) + total_bytes = sum(s.current_bytes for s in self._streams.values()) + total_subs = sum(s.subscriber_count for s in self._streams.values()) + return { + "total_streams": len(self._streams), + "active_streams": active, + "terminal_streams": terminal, + "total_events": total_events, + "total_bytes": total_bytes, + "total_subscribers": total_subs, + "closed": self._closed, + } + + def _resolve_resume_point( + self, + stream: _RunStream, + last_event_id: str | None, + ) -> ResumeResult: + if last_event_id is None: + return 
ResumeResult(next_offset=stream.start_offset, status="fresh") + if last_event_id in stream.id_to_offset: + return ResumeResult( + next_offset=stream.id_to_offset[last_event_id] + 1, + status="resumed", + ) + + parts = last_event_id.split("-") + if len(parts) != 2: + return ResumeResult(next_offset=stream.start_offset, status="invalid") + try: + event_ts = int(parts[0]) + _event_seq = int(parts[1]) + except ValueError: + return ResumeResult(next_offset=stream.start_offset, status="invalid") + + if stream.events: + try: + oldest_parts = stream.events[0].id.split("-") + oldest_ts = int(oldest_parts[0]) + if event_ts < oldest_ts: + return ResumeResult( + next_offset=stream.start_offset, + status="evicted", + gap_count=stream.start_offset, + ) + except (ValueError, IndexError): + pass + + return ResumeResult(next_offset=stream.start_offset, status="unknown") + + +__all__ = ["MemoryStreamBridge"] diff --git a/backend/app/infra/stream_bridge/adapters/redis.py b/backend/app/infra/stream_bridge/adapters/redis.py new file mode 100644 index 000000000..e124c8f9f --- /dev/null +++ b/backend/app/infra/stream_bridge/adapters/redis.py @@ -0,0 +1,37 @@ +"""Redis-backed stream bridge placeholder owned by the app layer.""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +from deerflow.runtime.stream_bridge import StreamBridge, StreamEvent + + +class RedisStreamBridge(StreamBridge): + """Reserved app-owned Redis implementation. + + Phase 1 intentionally keeps Redis out of the harness package. The concrete + implementation will live here once cross-process streaming is introduced. 
+ """ + + def __init__(self, *, redis_url: str) -> None: + self._redis_url = redis_url + + async def publish(self, run_id: str, event: str, data: Any) -> str: + raise NotImplementedError("Redis stream bridge will be implemented in app infra") + + async def publish_end(self, run_id: str) -> str: + raise NotImplementedError("Redis stream bridge will be implemented in app infra") + + def subscribe( + self, + run_id: str, + *, + last_event_id: str | None = None, + heartbeat_interval: float = 15.0, + ) -> AsyncIterator[StreamEvent]: + raise NotImplementedError("Redis stream bridge will be implemented in app infra") + + async def cleanup(self, run_id: str, *, delay: float = 0) -> None: + raise NotImplementedError("Redis stream bridge will be implemented in app infra") diff --git a/backend/app/infra/stream_bridge/factory.py b/backend/app/infra/stream_bridge/factory.py new file mode 100644 index 000000000..2c777551f --- /dev/null +++ b/backend/app/infra/stream_bridge/factory.py @@ -0,0 +1,50 @@ +"""App-owned stream bridge factory.""" + +from __future__ import annotations + +import logging +from collections.abc import AsyncIterator +from contextlib import AbstractAsyncContextManager, asynccontextmanager + +from deerflow.config.stream_bridge_config import get_stream_bridge_config +from deerflow.runtime.stream_bridge import StreamBridge + +from .adapters import MemoryStreamBridge, RedisStreamBridge + +logger = logging.getLogger(__name__) + + +def build_stream_bridge(config=None) -> AbstractAsyncContextManager[StreamBridge]: + """Build the configured app-owned stream bridge.""" + return _build_stream_bridge_impl(config) + + +@asynccontextmanager +async def _build_stream_bridge_impl(config=None) -> AsyncIterator[StreamBridge]: + if config is None: + config = get_stream_bridge_config() + + if config is None or config.type == "memory": + maxsize = config.queue_maxsize if config is not None else 256 + bridge = MemoryStreamBridge(queue_maxsize=maxsize) + await bridge.start() + 
logger.info("Stream bridge initialised: memory (queue_maxsize=%s)", maxsize) + try: + yield bridge + finally: + await bridge.close() + return + + if config.type == "redis": + if not config.redis_url: + raise ValueError("Redis stream bridge requires redis_url") + bridge = RedisStreamBridge(redis_url=config.redis_url) + await bridge.start()  # NOTE(review): RedisStreamBridge defines no start()/close() in this diff — confirm the StreamBridge base supplies them + logger.info("Stream bridge initialised: redis (%s)", config.redis_url) + try: + yield bridge + finally: + await bridge.close() + return + + raise ValueError(f"Unknown stream bridge type: {config.type!r}")