mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat(infra): add new infrastructure layer for storage and streaming
Add app/infra/ package with: - storage/ - repository adapters for runs, run_events, thread_meta - run_events/ - JSONL-based event store with factory - stream_bridge/ - memory and redis adapters for SSE streaming This layer provides the persistence abstractions used by the gateway services, replacing the old deerflow/persistence modules. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
14892e1463
commit
274255b1a5
1
backend/app/infra/__init__.py
Normal file
1
backend/app/infra/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Application-owned infrastructure adapters and wiring."""
|
||||
6
backend/app/infra/run_events/__init__.py
Normal file
6
backend/app/infra/run_events/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Run event store backends owned by app infrastructure."""
|
||||
|
||||
from .factory import build_run_event_store
|
||||
from .jsonl_store import JsonlRunEventStore
|
||||
|
||||
__all__ = ["JsonlRunEventStore", "build_run_event_store"]
|
||||
25
backend/app/infra/run_events/factory.py
Normal file
25
backend/app/infra/run_events/factory.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""Factory for app-owned run event store backends."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.infra.storage import AppRunEventStore
|
||||
from deerflow.config import get_app_config
|
||||
|
||||
from .jsonl_store import JsonlRunEventStore
|
||||
|
||||
|
||||
def build_run_event_store(session_factory: async_sessionmaker[AsyncSession]) -> AppRunEventStore | JsonlRunEventStore:
|
||||
"""Build the run event store selected by app configuration."""
|
||||
|
||||
config = get_app_config().run_events
|
||||
if config.backend == "db":
|
||||
return AppRunEventStore(session_factory)
|
||||
if config.backend == "jsonl":
|
||||
return JsonlRunEventStore(
|
||||
base_dir=Path(config.jsonl_base_dir),
|
||||
)
|
||||
raise ValueError(f"Unsupported run event backend: {config.backend}")
|
||||
210
backend/app/infra/run_events/jsonl_store.py
Normal file
210
backend/app/infra/run_events/jsonl_store.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""JSONL run event store backend owned by app infrastructure."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import shutil
|
||||
from collections.abc import Iterable
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class JsonlRunEventStore:
|
||||
"""Append-only JSONL implementation of the runs RunEventStore protocol."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: Path | str = ".deer-flow/run-events",
|
||||
) -> None:
|
||||
self._base_dir = Path(base_dir)
|
||||
self._locks: dict[str, asyncio.Lock] = {}
|
||||
self._locks_guard = asyncio.Lock()
|
||||
|
||||
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if not events:
|
||||
return []
|
||||
|
||||
grouped: dict[str, list[dict[str, Any]]] = {}
|
||||
for event in events:
|
||||
grouped.setdefault(str(event["thread_id"]), []).append(event)
|
||||
|
||||
records_by_thread: dict[str, list[dict[str, Any]]] = {}
|
||||
for thread_id, thread_events in grouped.items():
|
||||
async with await self._thread_lock(thread_id):
|
||||
records_by_thread[thread_id] = self._append_thread_events(thread_id, thread_events)
|
||||
|
||||
indexes = {thread_id: 0 for thread_id in records_by_thread}
|
||||
ordered: list[dict[str, Any]] = []
|
||||
for event in events:
|
||||
thread_id = str(event["thread_id"])
|
||||
index = indexes[thread_id]
|
||||
ordered.append(records_by_thread[thread_id][index])
|
||||
indexes[thread_id] = index + 1
|
||||
return ordered
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
events = [event for event in await self._read_thread_events(thread_id) if event.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||
return events[-limit:]
|
||||
if after_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||
return events[:limit]
|
||||
return events[-limit:]
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
) -> list[dict[str, Any]]:
|
||||
event_type_set = set(event_types or [])
|
||||
events = [
|
||||
event
|
||||
for event in await self._read_thread_events(thread_id)
|
||||
if event.get("run_id") == run_id and (not event_type_set or event.get("event_type") in event_type_set)
|
||||
]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
events = [
|
||||
event
|
||||
for event in await self._read_thread_events(thread_id)
|
||||
if event.get("run_id") == run_id and event.get("category") == "message"
|
||||
]
|
||||
if before_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) < before_seq]
|
||||
return events[-limit:]
|
||||
if after_seq is not None:
|
||||
events = [event for event in events if int(event["seq"]) > after_seq]
|
||||
return events[:limit]
|
||||
return events[-limit:]
|
||||
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
return len(await self.list_messages(thread_id, limit=10**9))
|
||||
|
||||
async def delete_by_thread(self, thread_id: str) -> int:
|
||||
async with await self._thread_lock(thread_id):
|
||||
count = len(self._read_thread_events_sync(thread_id))
|
||||
shutil.rmtree(self._thread_dir(thread_id), ignore_errors=True)
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||
async with await self._thread_lock(thread_id):
|
||||
events = self._read_thread_events_sync(thread_id)
|
||||
kept = [event for event in events if event.get("run_id") != run_id]
|
||||
deleted = len(events) - len(kept)
|
||||
if deleted:
|
||||
self._write_thread_events(thread_id, kept)
|
||||
return deleted
|
||||
|
||||
async def _thread_lock(self, thread_id: str) -> asyncio.Lock:
|
||||
async with self._locks_guard:
|
||||
lock = self._locks.get(thread_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._locks[thread_id] = lock
|
||||
return lock
|
||||
|
||||
def _append_thread_events(self, thread_id: str, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
seq = self._read_seq(thread_id)
|
||||
records: list[dict[str, Any]] = []
|
||||
with self._events_path(thread_id).open("a", encoding="utf-8") as file:
|
||||
for event in events:
|
||||
seq += 1
|
||||
record = self._normalize_event(event, seq=seq)
|
||||
file.write(json.dumps(record, ensure_ascii=False, default=str))
|
||||
file.write("\n")
|
||||
records.append(record)
|
||||
self._write_seq(thread_id, seq)
|
||||
return records
|
||||
|
||||
def _normalize_event(self, event: dict[str, Any], *, seq: int) -> dict[str, Any]:
|
||||
created_at = event.get("created_at")
|
||||
if isinstance(created_at, datetime):
|
||||
created_at_value = created_at.isoformat()
|
||||
elif created_at:
|
||||
created_at_value = str(created_at)
|
||||
else:
|
||||
created_at_value = datetime.now(UTC).isoformat()
|
||||
|
||||
return {
|
||||
"thread_id": str(event["thread_id"]),
|
||||
"run_id": str(event["run_id"]),
|
||||
"seq": seq,
|
||||
"event_type": str(event["event_type"]),
|
||||
"category": str(event["category"]),
|
||||
"content": event.get("content", ""),
|
||||
"metadata": dict(event.get("metadata") or {}),
|
||||
"created_at": created_at_value,
|
||||
}
|
||||
|
||||
async def _read_thread_events(self, thread_id: str) -> list[dict[str, Any]]:
|
||||
async with await self._thread_lock(thread_id):
|
||||
return self._read_thread_events_sync(thread_id)
|
||||
|
||||
def _read_thread_events_sync(self, thread_id: str) -> list[dict[str, Any]]:
|
||||
path = self._events_path(thread_id)
|
||||
if not path.exists():
|
||||
return []
|
||||
|
||||
events: list[dict[str, Any]] = []
|
||||
with path.open(encoding="utf-8") as file:
|
||||
for line in file:
|
||||
stripped = line.strip()
|
||||
if stripped:
|
||||
events.append(json.loads(stripped))
|
||||
return events
|
||||
|
||||
def _write_thread_events(self, thread_id: str, events: Iterable[dict[str, Any]]) -> None:
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
thread_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = self._events_path(thread_id).with_suffix(".jsonl.tmp")
|
||||
with temp_path.open("w", encoding="utf-8") as file:
|
||||
for event in events:
|
||||
file.write(json.dumps(event, ensure_ascii=False, default=str))
|
||||
file.write("\n")
|
||||
temp_path.replace(self._events_path(thread_id))
|
||||
|
||||
def _read_seq(self, thread_id: str) -> int:
|
||||
path = self._seq_path(thread_id)
|
||||
if not path.exists():
|
||||
return 0
|
||||
try:
|
||||
return int(path.read_text(encoding="utf-8").strip() or "0")
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
def _write_seq(self, thread_id: str, seq: int) -> None:
|
||||
self._seq_path(thread_id).write_text(str(seq), encoding="utf-8")
|
||||
|
||||
def _thread_dir(self, thread_id: str) -> Path:
|
||||
return self._base_dir / "threads" / thread_id
|
||||
|
||||
def _events_path(self, thread_id: str) -> Path:
|
||||
return self._thread_dir(thread_id) / "events.jsonl"
|
||||
|
||||
def _seq_path(self, thread_id: str) -> Path:
|
||||
return self._thread_dir(thread_id) / "seq"
|
||||
14
backend/app/infra/storage/__init__.py
Normal file
14
backend/app/infra/storage/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""Storage-facing adapters owned by the app layer."""
|
||||
|
||||
from .run_events import AppRunEventStore
|
||||
from .runs import FeedbackStoreAdapter, RunStoreAdapter, StorageRunObserver
|
||||
from .thread_meta import ThreadMetaStorage, ThreadMetaStoreAdapter
|
||||
|
||||
__all__ = [
|
||||
"AppRunEventStore",
|
||||
"FeedbackStoreAdapter",
|
||||
"RunStoreAdapter",
|
||||
"StorageRunObserver",
|
||||
"ThreadMetaStorage",
|
||||
"ThreadMetaStoreAdapter",
|
||||
]
|
||||
166
backend/app/infra/storage/run_events.py
Normal file
166
backend/app/infra/storage/run_events.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""App-owned adapter from runs callbacks to storage run event repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from store.repositories import RunEvent, RunEventCreate, build_run_event_repository, build_thread_meta_repository
|
||||
|
||||
from deerflow.runtime.actor_context import get_actor_context
|
||||
|
||||
|
||||
class AppRunEventStore:
|
||||
"""Implements the harness RunEventStore protocol using storage repositories."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def put_batch(self, events: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
if not events:
|
||||
return []
|
||||
|
||||
denied = {str(event["thread_id"]) for event in events if not await self._thread_visible(str(event["thread_id"]))}
|
||||
if denied:
|
||||
raise PermissionError(f"actor is not allowed to append events for thread(s): {', '.join(sorted(denied))}")
|
||||
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.append_batch([_event_create_from_dict(event) for event in events])
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return []
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.list_messages(
|
||||
thread_id,
|
||||
limit=limit,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return []
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.list_events(thread_id, run_id, event_types=event_types, limit=limit)
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return []
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.list_messages_by_run(
|
||||
thread_id,
|
||||
run_id,
|
||||
limit=limit,
|
||||
before_seq=before_seq,
|
||||
after_seq=after_seq,
|
||||
)
|
||||
return [_event_to_dict(row) for row in rows]
|
||||
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return 0
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
return await repo.count_messages(thread_id)
|
||||
|
||||
async def delete_by_thread(self, thread_id: str) -> int:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return 0
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
repo = build_run_event_repository(session)
|
||||
count = await repo.delete_by_thread(thread_id)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||
if not await self._thread_visible(thread_id):
|
||||
return 0
|
||||
async with self._session_factory() as session:
|
||||
try:
|
||||
repo = build_run_event_repository(session)
|
||||
count = await repo.delete_by_run(thread_id, run_id)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
return count
|
||||
|
||||
async def _thread_visible(self, thread_id: str) -> bool:
|
||||
actor = get_actor_context()
|
||||
if actor is None or actor.user_id is None:
|
||||
return True
|
||||
|
||||
async with self._session_factory() as session:
|
||||
thread_repo = build_thread_meta_repository(session)
|
||||
thread = await thread_repo.get_thread_meta(thread_id)
|
||||
|
||||
if thread is None:
|
||||
return True
|
||||
return thread.user_id is None or thread.user_id == actor.user_id
|
||||
|
||||
|
||||
def _event_create_from_dict(event: dict[str, Any]) -> RunEventCreate:
|
||||
created_at = event.get("created_at")
|
||||
return RunEventCreate(
|
||||
thread_id=str(event["thread_id"]),
|
||||
run_id=str(event["run_id"]),
|
||||
event_type=str(event["event_type"]),
|
||||
category=str(event["category"]),
|
||||
content=event.get("content", ""),
|
||||
metadata=dict(event.get("metadata") or {}),
|
||||
created_at=datetime.fromisoformat(created_at) if isinstance(created_at, str) else created_at,
|
||||
)
|
||||
|
||||
|
||||
def _event_to_dict(event: RunEvent) -> dict[str, Any]:
|
||||
return {
|
||||
"thread_id": event.thread_id,
|
||||
"run_id": event.run_id,
|
||||
"event_type": event.event_type,
|
||||
"category": event.category,
|
||||
"content": event.content,
|
||||
"metadata": event.metadata,
|
||||
"seq": event.seq,
|
||||
"created_at": event.created_at.isoformat(),
|
||||
}
|
||||
515
backend/app/infra/storage/runs.py
Normal file
515
backend/app/infra/storage/runs.py
Normal file
@ -0,0 +1,515 @@
|
||||
"""Run lifecycle persistence adapters owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from typing import Protocol, TypedDict, Unpack, cast
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from store.repositories import FeedbackCreate, Run, RunCreate, build_feedback_repository, build_run_repository
|
||||
|
||||
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||
from deerflow.runtime.serialization import serialize_lc_object
|
||||
from deerflow.runtime.runs.observer import LifecycleEventType, RunLifecycleEvent, RunObserver
|
||||
from deerflow.runtime.stream_bridge import JSONValue
|
||||
|
||||
from .thread_meta import ThreadMetaStorage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunCreateFields(TypedDict, total=False):
|
||||
status: str
|
||||
created_at: str
|
||||
started_at: str
|
||||
ended_at: str
|
||||
assistant_id: str | None
|
||||
user_id: str | None
|
||||
follow_up_to_run_id: str | None
|
||||
metadata: dict[str, JSONValue]
|
||||
kwargs: dict[str, JSONValue]
|
||||
|
||||
|
||||
class RunStatusUpdateFields(TypedDict, total=False):
|
||||
started_at: str
|
||||
ended_at: str
|
||||
metadata: dict[str, JSONValue]
|
||||
|
||||
|
||||
class RunCompletionFields(TypedDict, total=False):
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_tokens: int
|
||||
llm_call_count: int
|
||||
lead_agent_tokens: int
|
||||
subagent_tokens: int
|
||||
middleware_tokens: int
|
||||
message_count: int
|
||||
last_ai_message: str | None
|
||||
first_human_message: str | None
|
||||
error: str | None
|
||||
|
||||
|
||||
class RunRow(TypedDict, total=False):
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None
|
||||
status: str
|
||||
multitask_strategy: str
|
||||
follow_up_to_run_id: str | None
|
||||
metadata: dict[str, JSONValue]
|
||||
created_at: str
|
||||
updated_at: str
|
||||
started_at: str | None
|
||||
ended_at: str | None
|
||||
error: str | None
|
||||
|
||||
|
||||
class RunReadRepository(Protocol):
|
||||
"""Protocol for durable run queries."""
|
||||
|
||||
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None: ...
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | object = AUTO,
|
||||
) -> list[RunRow]: ...
|
||||
|
||||
|
||||
class RunWriteRepository(Protocol):
|
||||
"""Protocol for durable run writes."""
|
||||
|
||||
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None: ...
|
||||
async def update_status(
|
||||
self,
|
||||
run_id: str,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunStatusUpdateFields],
|
||||
) -> None: ...
|
||||
async def set_error(self, run_id: str, error: str) -> None: ...
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunCompletionFields],
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class RunDeleteRepository(Protocol):
|
||||
"""Protocol for durable run deletion."""
|
||||
|
||||
async def delete(self, run_id: str) -> bool: ...
|
||||
|
||||
|
||||
class _RepositoryContext:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
build_repo: Callable[[AsyncSession], object],
|
||||
*,
|
||||
commit: bool,
|
||||
) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._build_repo = build_repo
|
||||
self._commit = commit
|
||||
self._session: AsyncSession | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self._session_factory()
|
||||
return self._build_repo(self._session)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if self._commit:
|
||||
if exc_type is None:
|
||||
await self._session.commit()
|
||||
else:
|
||||
await self._session.rollback()
|
||||
finally:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
def _run_to_row(row: Run) -> RunRow:
|
||||
return {
|
||||
"run_id": row.run_id,
|
||||
"thread_id": row.thread_id,
|
||||
"assistant_id": row.assistant_id,
|
||||
"user_id": row.user_id,
|
||||
"status": row.status,
|
||||
"model_name": row.model_name,
|
||||
"multitask_strategy": row.multitask_strategy,
|
||||
"follow_up_to_run_id": row.follow_up_to_run_id,
|
||||
"metadata": cast(dict[str, JSONValue], row.metadata),
|
||||
"kwargs": cast(dict[str, JSONValue], row.kwargs),
|
||||
"created_at": row.created_time.isoformat(),
|
||||
"updated_at": row.updated_time.isoformat() if row.updated_time else "",
|
||||
"total_input_tokens": row.total_input_tokens,
|
||||
"total_output_tokens": row.total_output_tokens,
|
||||
"total_tokens": row.total_tokens,
|
||||
"llm_call_count": row.llm_call_count,
|
||||
"lead_agent_tokens": row.lead_agent_tokens,
|
||||
"subagent_tokens": row.subagent_tokens,
|
||||
"middleware_tokens": row.middleware_tokens,
|
||||
"message_count": row.message_count,
|
||||
"first_human_message": row.first_human_message,
|
||||
"last_ai_message": row.last_ai_message,
|
||||
"error": row.error,
|
||||
}
|
||||
|
||||
|
||||
class FeedbackStoreAdapter:
|
||||
"""Expose feedback route semantics on top of storage package repositories."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
owner_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict[str, object]:
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
effective_user_id = user_id if user_id is not None else owner_id
|
||||
async with self._transaction() as repo:
|
||||
row = await repo.create_feedback(
|
||||
FeedbackCreate(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=rating,
|
||||
user_id=effective_user_id,
|
||||
message_id=message_id,
|
||||
comment=comment,
|
||||
)
|
||||
)
|
||||
return _feedback_to_dict(row)
|
||||
|
||||
async def get(self, feedback_id: str) -> dict[str, object] | None:
|
||||
async with self._read() as repo:
|
||||
row = await repo.get_feedback(feedback_id)
|
||||
return _feedback_to_dict(row) if row is not None else None
|
||||
|
||||
async def list_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None = None,
|
||||
) -> list[dict[str, object]]:
|
||||
async with self._read() as repo:
|
||||
rows = await repo.list_feedback_by_run(run_id)
|
||||
filtered = [row for row in rows if row.thread_id == thread_id]
|
||||
if user_id is not None:
|
||||
filtered = [row for row in filtered if row.user_id == user_id]
|
||||
return [_feedback_to_dict(row) for row in filtered][:limit]
|
||||
|
||||
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict[str, object]]:
|
||||
async with self._read() as repo:
|
||||
rows = await repo.list_feedback_by_thread(thread_id)
|
||||
return [_feedback_to_dict(row) for row in rows][:limit]
|
||||
|
||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict[str, object]:
|
||||
rows = await self.list_by_run(thread_id, run_id)
|
||||
positive = sum(1 for row in rows if row["rating"] == 1)
|
||||
negative = sum(1 for row in rows if row["rating"] == -1)
|
||||
return {"run_id": run_id, "total": len(rows), "positive": positive, "negative": negative}
|
||||
|
||||
async def delete(self, feedback_id: str) -> bool:
|
||||
async with self._transaction() as repo:
|
||||
return await repo.delete_feedback(feedback_id)
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
user_id: str,
|
||||
comment: str | None = None,
|
||||
) -> dict[str, object]:
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
async with self._transaction() as repo:
|
||||
rows = await repo.list_feedback_by_run(run_id)
|
||||
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
|
||||
feedback_id = existing.feedback_id if existing is not None else str(uuid.uuid4())
|
||||
if existing is not None:
|
||||
await repo.delete_feedback(existing.feedback_id)
|
||||
row = await repo.create_feedback(
|
||||
FeedbackCreate(
|
||||
feedback_id=feedback_id,
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
rating=rating,
|
||||
user_id=user_id,
|
||||
comment=comment,
|
||||
)
|
||||
)
|
||||
return _feedback_to_dict(row)
|
||||
|
||||
async def delete_by_run(self, *, thread_id: str, run_id: str, user_id: str) -> bool:
|
||||
async with self._transaction() as repo:
|
||||
rows = await repo.list_feedback_by_run(run_id)
|
||||
existing = next((row for row in rows if row.thread_id == thread_id and row.user_id == user_id), None)
|
||||
if existing is None:
|
||||
return False
|
||||
return await repo.delete_feedback(existing.feedback_id)
|
||||
|
||||
async def list_by_thread_grouped(self, thread_id: str, *, user_id: str) -> dict[str, dict[str, object]]:
|
||||
rows = await self.list_by_thread(thread_id)
|
||||
return {
|
||||
row["run_id"]: row
|
||||
for row in rows
|
||||
if row["user_id"] == user_id
|
||||
}
|
||||
|
||||
def _read(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=False)
|
||||
|
||||
def _transaction(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_feedback_repository, commit=True)
|
||||
|
||||
|
||||
def _feedback_to_dict(row) -> dict[str, object]:
|
||||
return {
|
||||
"feedback_id": row.feedback_id,
|
||||
"run_id": row.run_id,
|
||||
"thread_id": row.thread_id,
|
||||
"user_id": row.user_id,
|
||||
"owner_id": row.user_id,
|
||||
"message_id": row.message_id,
|
||||
"rating": row.rating,
|
||||
"comment": row.comment,
|
||||
"created_at": row.created_time.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class RunStoreAdapter:
|
||||
"""Expose runs facade storage semantics on top of storage package repositories."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def get(self, run_id: str, *, user_id: str | None | object = AUTO) -> RunRow | None:
|
||||
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.get")
|
||||
async with self._read() as repo:
|
||||
row = await repo.get_run(run_id)
|
||||
if row is None:
|
||||
return None
|
||||
if effective_user_id is not None and row.user_id != effective_user_id:
|
||||
return None
|
||||
return _run_to_row(row)
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | object = AUTO,
|
||||
) -> list[RunRow]:
|
||||
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.list_by_thread")
|
||||
async with self._read() as repo:
|
||||
rows = await repo.list_runs_by_thread(thread_id, limit=limit, offset=0)
|
||||
if effective_user_id is not None:
|
||||
rows = [row for row in rows if row.user_id == effective_user_id]
|
||||
return [_run_to_row(row) for row in rows]
|
||||
|
||||
async def create(self, run_id: str, thread_id: str, **kwargs: Unpack[RunCreateFields]) -> None:
|
||||
metadata = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("metadata") or {}))
|
||||
run_kwargs = cast(dict[str, JSONValue], serialize_lc_object(kwargs.get("kwargs") or {}))
|
||||
effective_user_id = resolve_user_id(kwargs.get("user_id", AUTO), method_name="RunStoreAdapter.create")
|
||||
async with self._transaction() as repo:
|
||||
await repo.create_run(
|
||||
RunCreate(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=kwargs.get("assistant_id"),
|
||||
user_id=effective_user_id,
|
||||
status=kwargs.get("status", "pending"),
|
||||
metadata=dict(metadata),
|
||||
kwargs=dict(run_kwargs),
|
||||
follow_up_to_run_id=kwargs.get("follow_up_to_run_id"),
|
||||
)
|
||||
)
|
||||
|
||||
async def delete(self, run_id: str, *, user_id: str | None | object = AUTO) -> bool:
|
||||
async with self._transaction() as repo:
|
||||
existing = await repo.get_run(run_id)
|
||||
if existing is None:
|
||||
return False
|
||||
effective_user_id = resolve_user_id(user_id, method_name="RunStoreAdapter.delete")
|
||||
if effective_user_id is not None and existing.user_id != effective_user_id:
|
||||
return False
|
||||
await repo.delete_run(run_id)
|
||||
return True
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
run_id: str,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunStatusUpdateFields],
|
||||
) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_run_status(run_id, status)
|
||||
|
||||
async def set_error(self, run_id: str, error: str) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_run_status(run_id, "error", error=error)
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
**kwargs: Unpack[RunCompletionFields],
|
||||
) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_run_completion(
|
||||
run_id,
|
||||
status=status,
|
||||
total_input_tokens=kwargs.get("total_input_tokens", 0),
|
||||
total_output_tokens=kwargs.get("total_output_tokens", 0),
|
||||
total_tokens=kwargs.get("total_tokens", 0),
|
||||
llm_call_count=kwargs.get("llm_call_count", 0),
|
||||
lead_agent_tokens=kwargs.get("lead_agent_tokens", 0),
|
||||
subagent_tokens=kwargs.get("subagent_tokens", 0),
|
||||
middleware_tokens=kwargs.get("middleware_tokens", 0),
|
||||
message_count=kwargs.get("message_count", 0),
|
||||
last_ai_message=kwargs.get("last_ai_message"),
|
||||
first_human_message=kwargs.get("first_human_message"),
|
||||
error=kwargs.get("error"),
|
||||
)
|
||||
|
||||
def _read(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_run_repository, commit=False)
|
||||
|
||||
def _transaction(self) -> _RepositoryContext:
|
||||
return _RepositoryContext(self._session_factory, build_run_repository, commit=True)
|
||||
|
||||
|
||||
class StorageRunObserver(RunObserver):
|
||||
"""Persist run lifecycle state into app-owned repositories."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_write_repo: RunWriteRepository | None = None,
|
||||
thread_meta_storage: ThreadMetaStorage | None = None,
|
||||
) -> None:
|
||||
self._run_write_repo = run_write_repo
|
||||
self._thread_meta_storage = thread_meta_storage
|
||||
|
||||
async def on_event(self, event: RunLifecycleEvent) -> None:
|
||||
try:
|
||||
await self._dispatch(event)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"StorageRunObserver failed to persist event %s for run %s",
|
||||
event.event_type,
|
||||
event.run_id,
|
||||
)
|
||||
|
||||
async def _dispatch(self, event: RunLifecycleEvent) -> None:
|
||||
handlers = {
|
||||
LifecycleEventType.RUN_STARTED: self._handle_run_started,
|
||||
LifecycleEventType.RUN_COMPLETED: self._handle_run_completed,
|
||||
LifecycleEventType.RUN_FAILED: self._handle_run_failed,
|
||||
LifecycleEventType.RUN_CANCELLED: self._handle_run_cancelled,
|
||||
LifecycleEventType.THREAD_STATUS_UPDATED: self._handle_thread_status,
|
||||
}
|
||||
|
||||
handler = handlers.get(event.event_type)
|
||||
if handler:
|
||||
await handler(event)
|
||||
|
||||
async def _handle_run_started(self, event: RunLifecycleEvent) -> None:
|
||||
if self._run_write_repo:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="running",
|
||||
started_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
|
||||
async def _handle_run_completed(self, event: RunLifecycleEvent) -> None:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
if self._run_write_repo:
|
||||
completion_data = payload.get("completion_data")
|
||||
if isinstance(completion_data, dict):
|
||||
await self._run_write_repo.update_run_completion(
|
||||
run_id=event.run_id,
|
||||
status="success",
|
||||
**cast(RunCompletionFields, completion_data),
|
||||
)
|
||||
else:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="success",
|
||||
ended_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
|
||||
if self._thread_meta_storage and "title" in payload:
|
||||
await self._thread_meta_storage.sync_thread_title(
|
||||
thread_id=event.thread_id,
|
||||
title=payload["title"],
|
||||
)
|
||||
|
||||
async def _handle_run_failed(self, event: RunLifecycleEvent) -> None:
|
||||
if self._run_write_repo:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
error = payload.get("error", "Unknown error")
|
||||
completion_data = payload.get("completion_data")
|
||||
if isinstance(completion_data, dict):
|
||||
await self._run_write_repo.update_run_completion(
|
||||
run_id=event.run_id,
|
||||
status="error",
|
||||
error=str(error),
|
||||
**cast(RunCompletionFields, completion_data),
|
||||
)
|
||||
else:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="error",
|
||||
ended_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
await self._run_write_repo.set_error(run_id=event.run_id, error=str(error))
|
||||
|
||||
async def _handle_run_cancelled(self, event: RunLifecycleEvent) -> None:
|
||||
if self._run_write_repo:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
completion_data = payload.get("completion_data")
|
||||
if isinstance(completion_data, dict):
|
||||
await self._run_write_repo.update_run_completion(
|
||||
run_id=event.run_id,
|
||||
status="interrupted",
|
||||
**cast(RunCompletionFields, completion_data),
|
||||
)
|
||||
else:
|
||||
await self._run_write_repo.update_status(
|
||||
run_id=event.run_id,
|
||||
status="interrupted",
|
||||
ended_at=event.occurred_at.isoformat(),
|
||||
)
|
||||
|
||||
async def _handle_thread_status(self, event: RunLifecycleEvent) -> None:
|
||||
if self._thread_meta_storage:
|
||||
payload = dict(event.payload) if event.payload else {}
|
||||
status = payload.get("status", "idle")
|
||||
await self._thread_meta_storage.sync_thread_status(
|
||||
thread_id=event.thread_id,
|
||||
status=status,
|
||||
)
|
||||
208
backend/app/infra/storage/thread_meta.py
Normal file
208
backend/app/infra/storage/thread_meta.py
Normal file
@ -0,0 +1,208 @@
|
||||
"""Thread metadata storage adapter owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from store.repositories import build_thread_meta_repository
|
||||
from store.repositories.contracts import (
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from deerflow.runtime.actor_context import AUTO, resolve_user_id
|
||||
|
||||
|
||||
class ThreadMetaStoreAdapter:
|
||||
"""Use storage package thread repositories with per-call sessions."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
|
||||
async with self._transaction() as repo:
|
||||
return await repo.create_thread_meta(data)
|
||||
|
||||
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
|
||||
async with self._read() as repo:
|
||||
return await repo.get_thread_meta(thread_id)
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.update_thread_meta(
|
||||
thread_id,
|
||||
assistant_id=assistant_id,
|
||||
display_name=display_name,
|
||||
status=status,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
async with self._transaction() as repo:
|
||||
await repo.delete_thread(thread_id)
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
async with self._read() as repo:
|
||||
return await repo.search_threads(
|
||||
metadata=metadata,
|
||||
status=status,
|
||||
user_id=user_id,
|
||||
assistant_id=assistant_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
def _read(self):
|
||||
return _ThreadMetaRepositoryContext(self._session_factory, commit=False)
|
||||
|
||||
def _transaction(self):
|
||||
return _ThreadMetaRepositoryContext(self._session_factory, commit=True)
|
||||
|
||||
|
||||
class _ThreadMetaRepositoryContext:
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, commit: bool) -> None:
|
||||
self._session_factory = session_factory
|
||||
self._commit = commit
|
||||
self._session: AsyncSession | None = None
|
||||
|
||||
async def __aenter__(self):
|
||||
self._session = self._session_factory()
|
||||
return build_thread_meta_repository(self._session)
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if self._commit:
|
||||
if exc_type is None:
|
||||
await self._session.commit()
|
||||
else:
|
||||
await self._session.rollback()
|
||||
finally:
|
||||
await self._session.close()
|
||||
|
||||
|
||||
class ThreadMetaStorage:
|
||||
"""App-facing adapter around the storage thread metadata contract."""
|
||||
|
||||
def __init__(self, repo: ThreadMetaRepositoryProtocol) -> None:
|
||||
self._repo = repo
|
||||
|
||||
async def get_thread(self, thread_id: str, *, user_id: str | None | object = AUTO) -> ThreadMeta | None:
|
||||
thread = await self._repo.get_thread_meta(thread_id)
|
||||
if thread is None:
|
||||
return None
|
||||
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.get_thread")
|
||||
if effective_user_id is not None and thread.user_id != effective_user_id:
|
||||
return None
|
||||
return thread
|
||||
|
||||
async def ensure_thread(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
user_id: str | None | object = AUTO,
|
||||
) -> ThreadMeta:
|
||||
effective_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.ensure_thread")
|
||||
existing = await self.get_thread(thread_id, user_id=effective_user_id)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
return await self._repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=effective_user_id,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
)
|
||||
|
||||
async def ensure_thread_running(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> ThreadMeta | None:
|
||||
existing = await self._repo.get_thread_meta(thread_id)
|
||||
if existing is None:
|
||||
return await self._repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
status="running",
|
||||
metadata=metadata or {},
|
||||
)
|
||||
)
|
||||
|
||||
await self._repo.update_thread_meta(thread_id, status="running")
|
||||
return await self._repo.get_thread_meta(thread_id)
|
||||
|
||||
async def sync_thread_title(self, *, thread_id: str, title: str) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, display_name=title)
|
||||
|
||||
async def sync_thread_assistant_id(self, *, thread_id: str, assistant_id: str) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, assistant_id=assistant_id)
|
||||
|
||||
async def sync_thread_status(self, *, thread_id: str, status: str) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, status=status)
|
||||
|
||||
async def sync_thread_metadata(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
metadata: dict[str, Any],
|
||||
) -> None:
|
||||
await self._repo.update_thread_meta(thread_id, metadata=metadata)
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
await self._repo.delete_thread(thread_id)
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None | object = AUTO,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
normalized_status = status.strip() if status is not None else None
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaStorage.search_threads")
|
||||
normalized_user_id = resolved_user_id.strip() if resolved_user_id is not None else None
|
||||
normalized_assistant_id = (
|
||||
assistant_id.strip() if assistant_id is not None else None
|
||||
)
|
||||
|
||||
return await self._repo.search_threads(
|
||||
metadata=metadata,
|
||||
status=normalized_status or None,
|
||||
user_id=normalized_user_id or None,
|
||||
assistant_id=normalized_assistant_id or None,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["ThreadMetaStorage", "ThreadMetaStoreAdapter"]
|
||||
6
backend/app/infra/stream_bridge/__init__.py
Normal file
6
backend/app/infra/stream_bridge/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""App-owned stream bridge adapters and factory."""
|
||||
|
||||
from .factory import build_stream_bridge
|
||||
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||
|
||||
__all__ = ["MemoryStreamBridge", "RedisStreamBridge", "build_stream_bridge"]
|
||||
6
backend/app/infra/stream_bridge/adapters/__init__.py
Normal file
6
backend/app/infra/stream_bridge/adapters/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
"""Concrete stream bridge adapters owned by the app layer."""
|
||||
|
||||
from .memory import MemoryStreamBridge
|
||||
from .redis import RedisStreamBridge
|
||||
|
||||
__all__ = ["MemoryStreamBridge", "RedisStreamBridge"]
|
||||
450
backend/app/infra/stream_bridge/adapters/memory.py
Normal file
450
backend/app/infra/stream_bridge/adapters/memory.py
Normal file
@ -0,0 +1,450 @@
|
||||
"""In-memory stream bridge implementation owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
from deerflow.runtime.stream_bridge import (
|
||||
CANCELLED_SENTINEL,
|
||||
END_SENTINEL,
|
||||
HEARTBEAT_SENTINEL,
|
||||
TERMINAL_STATES,
|
||||
ResumeResult,
|
||||
StreamBridge,
|
||||
StreamEvent,
|
||||
StreamStatus,
|
||||
)
|
||||
from deerflow.runtime.stream_bridge.exceptions import (
|
||||
BridgeClosedError,
|
||||
StreamCapacityExceededError,
|
||||
StreamTerminatedError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RunStream:
|
||||
condition: asyncio.Condition = field(default_factory=asyncio.Condition)
|
||||
events: list[StreamEvent] = field(default_factory=list)
|
||||
id_to_offset: dict[str, int] = field(default_factory=dict)
|
||||
start_offset: int = 0
|
||||
current_bytes: int = 0
|
||||
seq: int = 0
|
||||
status: StreamStatus = StreamStatus.ACTIVE
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
last_publish_at: float | None = None
|
||||
ended_at: float | None = None
|
||||
subscriber_count: int = 0
|
||||
last_subscribe_at: float | None = None
|
||||
awaiting_input: bool = False
|
||||
awaiting_since: float | None = None
|
||||
|
||||
|
||||
class MemoryStreamBridge(StreamBridge):
|
||||
"""Per-run in-memory event log implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
max_events_per_stream: int = 256,
|
||||
max_bytes_per_stream: int = 10 * 1024 * 1024,
|
||||
max_active_streams: int = 1000,
|
||||
stream_eviction_policy: Literal["reject", "lru"] = "lru",
|
||||
terminal_retention_ttl: float = 300.0,
|
||||
active_no_publish_timeout: float = 600.0,
|
||||
orphan_timeout: float = 60.0,
|
||||
max_stream_age: float = 86400.0,
|
||||
hitl_extended_timeout: float = 7200.0,
|
||||
cleanup_interval: float = 30.0,
|
||||
queue_maxsize: int | None = None,
|
||||
) -> None:
|
||||
if queue_maxsize is not None:
|
||||
max_events_per_stream = queue_maxsize
|
||||
|
||||
self._max_events = max_events_per_stream
|
||||
self._max_bytes = max_bytes_per_stream
|
||||
self._max_streams = max_active_streams
|
||||
self._eviction_policy = stream_eviction_policy
|
||||
self._terminal_ttl = terminal_retention_ttl
|
||||
self._active_timeout = active_no_publish_timeout
|
||||
self._orphan_timeout = orphan_timeout
|
||||
self._max_age = max_stream_age
|
||||
self._hitl_timeout = hitl_extended_timeout
|
||||
self._cleanup_interval = cleanup_interval
|
||||
self._streams: dict[str, _RunStream] = {}
|
||||
self._registry_lock = asyncio.Lock()
|
||||
self._closed = False
|
||||
self._cleanup_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._cleanup_task is None:
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
logger.info(
|
||||
"MemoryStreamBridge started (max_events=%d, max_bytes=%d, max_streams=%d)",
|
||||
self._max_events,
|
||||
self._max_bytes,
|
||||
self._max_streams,
|
||||
)
|
||||
|
||||
async def close(self) -> None:
|
||||
async with self._registry_lock:
|
||||
self._closed = True
|
||||
if self._cleanup_task is not None:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._cleanup_task = None
|
||||
|
||||
for stream in self._streams.values():
|
||||
async with stream.condition:
|
||||
stream.status = StreamStatus.CLOSED
|
||||
stream.condition.notify_all()
|
||||
|
||||
self._streams.clear()
|
||||
logger.info("MemoryStreamBridge closed")
|
||||
|
||||
async def _get_or_create_stream(self, run_id: str) -> _RunStream:
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is not None:
|
||||
return stream
|
||||
|
||||
async with self._registry_lock:
|
||||
if self._closed:
|
||||
raise BridgeClosedError("Stream bridge is closed")
|
||||
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is not None:
|
||||
return stream
|
||||
|
||||
if len(self._streams) >= self._max_streams:
|
||||
if self._eviction_policy == "reject":
|
||||
raise StreamCapacityExceededError(
|
||||
f"Max {self._max_streams} active streams reached"
|
||||
)
|
||||
evicted = self._evict_oldest_terminal()
|
||||
if evicted is None:
|
||||
raise StreamCapacityExceededError("All streams active, cannot evict")
|
||||
logger.info("Evicted stream %s to make room", evicted)
|
||||
|
||||
stream = _RunStream()
|
||||
self._streams[run_id] = stream
|
||||
logger.debug("Created stream for run %s", run_id)
|
||||
return stream
|
||||
|
||||
def _evict_oldest_terminal(self) -> str | None:
|
||||
oldest_run_id: str | None = None
|
||||
oldest_ended_at: float = float("inf")
|
||||
for run_id, stream in self._streams.items():
|
||||
if stream.status in TERMINAL_STATES and stream.ended_at is not None:
|
||||
if stream.ended_at < oldest_ended_at:
|
||||
oldest_ended_at = stream.ended_at
|
||||
oldest_run_id = run_id
|
||||
if oldest_run_id is not None:
|
||||
del self._streams[oldest_run_id]
|
||||
return oldest_run_id
|
||||
return None
|
||||
|
||||
def _next_id(self, stream: _RunStream) -> str:
|
||||
stream.seq += 1
|
||||
return f"{int(time.time() * 1000)}-{stream.seq}"
|
||||
|
||||
def _estimate_size(self, event: StreamEvent) -> int:
|
||||
base = len(event.id) + len(event.event) + 100
|
||||
if event.data is None:
|
||||
return base
|
||||
if isinstance(event.data, str):
|
||||
return base + len(event.data)
|
||||
if isinstance(event.data, (dict, list)):
|
||||
try:
|
||||
return base + len(json.dumps(event.data, default=str))
|
||||
except (TypeError, ValueError):
|
||||
return base + 200
|
||||
return base + 50
|
||||
|
||||
def _evict_overflow(self, stream: _RunStream) -> None:
|
||||
while len(stream.events) > self._max_events or stream.current_bytes > self._max_bytes:
|
||||
if not stream.events:
|
||||
break
|
||||
evicted = stream.events.pop(0)
|
||||
stream.id_to_offset.pop(evicted.id, None)
|
||||
stream.current_bytes -= self._estimate_size(evicted)
|
||||
stream.start_offset += 1
|
||||
|
||||
async def publish(self, run_id: str, event: str, data: Any) -> str:
|
||||
stream = await self._get_or_create_stream(run_id)
|
||||
async with stream.condition:
|
||||
if stream.status != StreamStatus.ACTIVE:
|
||||
raise StreamTerminatedError(
|
||||
f"Cannot publish to {stream.status.value} stream"
|
||||
)
|
||||
|
||||
entry = StreamEvent(id=self._next_id(stream), event=event, data=data)
|
||||
absolute_offset = stream.start_offset + len(stream.events)
|
||||
stream.events.append(entry)
|
||||
stream.id_to_offset[entry.id] = absolute_offset
|
||||
stream.current_bytes += self._estimate_size(entry)
|
||||
stream.last_publish_at = time.monotonic()
|
||||
self._evict_overflow(stream)
|
||||
stream.condition.notify_all()
|
||||
return entry.id
|
||||
|
||||
async def publish_end(self, run_id: str) -> str:
|
||||
return await self.publish_terminal(run_id, StreamStatus.ENDED)
|
||||
|
||||
async def publish_terminal(
|
||||
self,
|
||||
run_id: str,
|
||||
kind: StreamStatus,
|
||||
data: Any = None,
|
||||
) -> str:
|
||||
if kind not in TERMINAL_STATES:
|
||||
raise ValueError(f"Invalid terminal kind: {kind}")
|
||||
|
||||
stream = await self._get_or_create_stream(run_id)
|
||||
async with stream.condition:
|
||||
if stream.status != StreamStatus.ACTIVE:
|
||||
for evt in reversed(stream.events):
|
||||
if evt.event in ("end", "cancel", "error", "dead_letter"):
|
||||
return evt.id
|
||||
return ""
|
||||
|
||||
event_name = {
|
||||
StreamStatus.ENDED: "end",
|
||||
StreamStatus.CANCELLED: "cancel",
|
||||
StreamStatus.ERRORED: "error",
|
||||
}[kind]
|
||||
entry = StreamEvent(id=self._next_id(stream), event=event_name, data=data)
|
||||
absolute_offset = stream.start_offset + len(stream.events)
|
||||
stream.events.append(entry)
|
||||
stream.id_to_offset[entry.id] = absolute_offset
|
||||
stream.current_bytes += self._estimate_size(entry)
|
||||
stream.status = kind
|
||||
stream.ended_at = time.monotonic()
|
||||
stream.awaiting_input = False
|
||||
stream.condition.notify_all()
|
||||
logger.debug("Stream %s terminal: %s", run_id, kind.value)
|
||||
return entry.id
|
||||
|
||||
async def cancel(self, run_id: str) -> None:
|
||||
await self.publish_terminal(run_id, StreamStatus.CANCELLED)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
heartbeat_interval: float = 15.0,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
stream = await self._get_or_create_stream(run_id)
|
||||
resume = self._resolve_resume_point(stream, last_event_id)
|
||||
next_offset = resume.next_offset
|
||||
|
||||
async with stream.condition:
|
||||
stream.subscriber_count += 1
|
||||
stream.last_subscribe_at = time.monotonic()
|
||||
|
||||
try:
|
||||
while True:
|
||||
entry_to_yield: StreamEvent | None = None
|
||||
sentinel_to_yield: StreamEvent | None = None
|
||||
should_return = False
|
||||
should_wait = False
|
||||
|
||||
async with stream.condition:
|
||||
if self._closed or stream.status == StreamStatus.CLOSED:
|
||||
sentinel_to_yield = CANCELLED_SENTINEL
|
||||
should_return = True
|
||||
elif next_offset < stream.start_offset:
|
||||
next_offset = stream.start_offset
|
||||
else:
|
||||
local_index = next_offset - stream.start_offset
|
||||
if 0 <= local_index < len(stream.events):
|
||||
entry_to_yield = stream.events[local_index]
|
||||
next_offset += 1
|
||||
if entry_to_yield.event in ("end", "cancel", "error", "dead_letter"):
|
||||
should_return = True
|
||||
elif stream.status in TERMINAL_STATES:
|
||||
sentinel_to_yield = END_SENTINEL
|
||||
should_return = True
|
||||
else:
|
||||
should_wait = True
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
stream.condition.wait(),
|
||||
timeout=heartbeat_interval,
|
||||
)
|
||||
except TimeoutError:
|
||||
pass
|
||||
|
||||
if sentinel_to_yield is not None:
|
||||
yield sentinel_to_yield
|
||||
if should_return:
|
||||
return
|
||||
continue
|
||||
|
||||
if entry_to_yield is not None:
|
||||
yield entry_to_yield
|
||||
if should_return:
|
||||
return
|
||||
continue
|
||||
|
||||
if should_wait:
|
||||
async with stream.condition:
|
||||
local_index = next_offset - stream.start_offset
|
||||
has_events = 0 <= local_index < len(stream.events)
|
||||
is_terminal = stream.status in TERMINAL_STATES
|
||||
if not has_events and not is_terminal:
|
||||
yield HEARTBEAT_SENTINEL
|
||||
|
||||
finally:
|
||||
async with stream.condition:
|
||||
stream.subscriber_count = max(0, stream.subscriber_count - 1)
|
||||
|
||||
async def mark_awaiting_input(self, run_id: str) -> None:
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is None:
|
||||
return
|
||||
async with stream.condition:
|
||||
if stream.status == StreamStatus.ACTIVE:
|
||||
stream.awaiting_input = True
|
||||
stream.awaiting_since = time.monotonic()
|
||||
logger.debug("Stream %s marked as awaiting input", run_id)
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
await self._do_cleanup(run_id, "manual")
|
||||
|
||||
async def _do_cleanup(self, run_id: str, reason: str) -> None:
|
||||
async with self._registry_lock:
|
||||
stream = self._streams.pop(run_id, None)
|
||||
if stream is not None:
|
||||
async with stream.condition:
|
||||
stream.status = StreamStatus.CLOSED
|
||||
stream.condition.notify_all()
|
||||
logger.debug("Cleaned up stream %s (reason: %s)", run_id, reason)
|
||||
|
||||
async def _mark_dead_letter(self, run_id: str, reason: str) -> None:
|
||||
stream = self._streams.get(run_id)
|
||||
if stream is None:
|
||||
return
|
||||
async with stream.condition:
|
||||
if stream.status != StreamStatus.ACTIVE:
|
||||
return
|
||||
entry = StreamEvent(
|
||||
id=self._next_id(stream),
|
||||
event="dead_letter",
|
||||
data={"reason": reason, "timestamp": time.time()},
|
||||
)
|
||||
absolute_offset = stream.start_offset + len(stream.events)
|
||||
stream.events.append(entry)
|
||||
stream.id_to_offset[entry.id] = absolute_offset
|
||||
stream.current_bytes += self._estimate_size(entry)
|
||||
stream.status = StreamStatus.ERRORED
|
||||
stream.ended_at = time.monotonic()
|
||||
stream.condition.notify_all()
|
||||
logger.warning("Stream %s marked as dead letter: %s", run_id, reason)
|
||||
|
||||
async def _cleanup_loop(self) -> None:
|
||||
while not self._closed:
|
||||
try:
|
||||
await asyncio.sleep(self._cleanup_interval)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
now = time.monotonic()
|
||||
to_cleanup: list[tuple[str, str]] = []
|
||||
to_mark_dead: list[tuple[str, str]] = []
|
||||
|
||||
async with self._registry_lock:
|
||||
for run_id, stream in list(self._streams.items()):
|
||||
if now - stream.created_at > self._max_age:
|
||||
to_cleanup.append((run_id, "max_age_exceeded"))
|
||||
continue
|
||||
|
||||
if stream.status == StreamStatus.ACTIVE:
|
||||
timeout = self._hitl_timeout if stream.awaiting_input else self._active_timeout
|
||||
last_activity = stream.last_publish_at or stream.created_at
|
||||
if now - last_activity > timeout:
|
||||
to_mark_dead.append((run_id, "no_publish_timeout"))
|
||||
continue
|
||||
|
||||
if stream.status in TERMINAL_STATES and stream.ended_at:
|
||||
if stream.subscriber_count > 0:
|
||||
continue
|
||||
last_sub = stream.last_subscribe_at or stream.ended_at
|
||||
if now - last_sub > self._orphan_timeout:
|
||||
to_cleanup.append((run_id, "orphan"))
|
||||
continue
|
||||
if now - stream.ended_at > self._terminal_ttl:
|
||||
to_cleanup.append((run_id, "ttl_expired"))
|
||||
|
||||
for run_id, reason in to_mark_dead:
|
||||
await self._mark_dead_letter(run_id, reason)
|
||||
for run_id, reason in to_cleanup:
|
||||
await self._do_cleanup(run_id, reason)
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
active = sum(1 for s in self._streams.values() if s.status == StreamStatus.ACTIVE)
|
||||
terminal = sum(1 for s in self._streams.values() if s.status in TERMINAL_STATES)
|
||||
total_events = sum(len(s.events) for s in self._streams.values())
|
||||
total_bytes = sum(s.current_bytes for s in self._streams.values())
|
||||
total_subs = sum(s.subscriber_count for s in self._streams.values())
|
||||
return {
|
||||
"total_streams": len(self._streams),
|
||||
"active_streams": active,
|
||||
"terminal_streams": terminal,
|
||||
"total_events": total_events,
|
||||
"total_bytes": total_bytes,
|
||||
"total_subscribers": total_subs,
|
||||
"closed": self._closed,
|
||||
}
|
||||
|
||||
def _resolve_resume_point(
|
||||
self,
|
||||
stream: _RunStream,
|
||||
last_event_id: str | None,
|
||||
) -> ResumeResult:
|
||||
if last_event_id is None:
|
||||
return ResumeResult(next_offset=stream.start_offset, status="fresh")
|
||||
if last_event_id in stream.id_to_offset:
|
||||
return ResumeResult(
|
||||
next_offset=stream.id_to_offset[last_event_id] + 1,
|
||||
status="resumed",
|
||||
)
|
||||
|
||||
parts = last_event_id.split("-")
|
||||
if len(parts) != 2:
|
||||
return ResumeResult(next_offset=stream.start_offset, status="invalid")
|
||||
try:
|
||||
event_ts = int(parts[0])
|
||||
_event_seq = int(parts[1])
|
||||
except ValueError:
|
||||
return ResumeResult(next_offset=stream.start_offset, status="invalid")
|
||||
|
||||
if stream.events:
|
||||
try:
|
||||
oldest_parts = stream.events[0].id.split("-")
|
||||
oldest_ts = int(oldest_parts[0])
|
||||
if event_ts < oldest_ts:
|
||||
return ResumeResult(
|
||||
next_offset=stream.start_offset,
|
||||
status="evicted",
|
||||
gap_count=stream.start_offset,
|
||||
)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
return ResumeResult(next_offset=stream.start_offset, status="unknown")
|
||||
|
||||
|
||||
__all__ = ["MemoryStreamBridge"]
|
||||
37
backend/app/infra/stream_bridge/adapters/redis.py
Normal file
37
backend/app/infra/stream_bridge/adapters/redis.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Redis-backed stream bridge placeholder owned by the app layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime.stream_bridge import StreamBridge, StreamEvent
|
||||
|
||||
|
||||
class RedisStreamBridge(StreamBridge):
|
||||
"""Reserved app-owned Redis implementation.
|
||||
|
||||
Phase 1 intentionally keeps Redis out of the harness package. The concrete
|
||||
implementation will live here once cross-process streaming is introduced.
|
||||
"""
|
||||
|
||||
def __init__(self, *, redis_url: str) -> None:
|
||||
self._redis_url = redis_url
|
||||
|
||||
async def publish(self, run_id: str, event: str, data: Any) -> str:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
|
||||
async def publish_end(self, run_id: str) -> str:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
|
||||
def subscribe(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
last_event_id: str | None = None,
|
||||
heartbeat_interval: float = 15.0,
|
||||
) -> AsyncIterator[StreamEvent]:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
|
||||
async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
|
||||
raise NotImplementedError("Redis stream bridge will be implemented in app infra")
|
||||
50
backend/app/infra/stream_bridge/factory.py
Normal file
50
backend/app/infra/stream_bridge/factory.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""App-owned stream bridge factory."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import AbstractAsyncContextManager, asynccontextmanager
|
||||
|
||||
from deerflow.config.stream_bridge_config import get_stream_bridge_config
|
||||
from deerflow.runtime.stream_bridge import StreamBridge
|
||||
|
||||
from .adapters import MemoryStreamBridge, RedisStreamBridge
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def build_stream_bridge(config=None) -> AbstractAsyncContextManager[StreamBridge]:
|
||||
"""Build the configured app-owned stream bridge."""
|
||||
return _build_stream_bridge_impl(config)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _build_stream_bridge_impl(config=None) -> AsyncIterator[StreamBridge]:
|
||||
if config is None:
|
||||
config = get_stream_bridge_config()
|
||||
|
||||
if config is None or config.type == "memory":
|
||||
maxsize = config.queue_maxsize if config is not None else 256
|
||||
bridge = MemoryStreamBridge(queue_maxsize=maxsize)
|
||||
await bridge.start()
|
||||
logger.info("Stream bridge initialised: memory (queue_maxsize=%d)", maxsize)
|
||||
try:
|
||||
yield bridge
|
||||
finally:
|
||||
await bridge.close()
|
||||
return
|
||||
|
||||
if config.type == "redis":
|
||||
if not config.redis_url:
|
||||
raise ValueError("Redis stream bridge requires redis_url")
|
||||
bridge = RedisStreamBridge(redis_url=config.redis_url)
|
||||
await bridge.start()
|
||||
logger.info("Stream bridge initialised: redis (%s)", config.redis_url)
|
||||
try:
|
||||
yield bridge
|
||||
finally:
|
||||
await bridge.close()
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown stream bridge type: {config.type!r}")
|
||||
Loading…
x
Reference in New Issue
Block a user