mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-15 13:13:45 +00:00
Add app/infra/ package with: storage/ (repository adapters for runs, run_events, thread_meta); run_events/ (JSONL-based event store with factory); stream_bridge/ (memory and redis adapters for SSE streaming). This layer provides the persistence abstractions used by the gateway services, replacing the old deerflow/persistence modules. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
451 lines
17 KiB
Python
451 lines
17 KiB
Python
"""In-memory stream bridge implementation owned by the app layer."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from collections.abc import AsyncIterator
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Literal
|
|
|
|
from deerflow.runtime.stream_bridge import (
|
|
CANCELLED_SENTINEL,
|
|
END_SENTINEL,
|
|
HEARTBEAT_SENTINEL,
|
|
TERMINAL_STATES,
|
|
ResumeResult,
|
|
StreamBridge,
|
|
StreamEvent,
|
|
StreamStatus,
|
|
)
|
|
from deerflow.runtime.stream_bridge.exceptions import (
|
|
BridgeClosedError,
|
|
StreamCapacityExceededError,
|
|
StreamTerminatedError,
|
|
)
|
|
|
|
# Module-scoped logger; records are emitted under this module's dotted name.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class _RunStream:
    """Mutable per-run state owned by ``MemoryStreamBridge``.

    Offsets are absolute positions in the run's full event history: the event
    stored at ``events[i]`` has offset ``start_offset + i``. Field order is
    significant because it defines the dataclass-generated ``__init__``.
    """

    # Lock plus wait/notify primitive; writers call notify_all() after changes
    # so waiting subscribers wake up.
    condition: asyncio.Condition = field(default_factory=asyncio.Condition)
    # Retained events, oldest first; the oldest are dropped when limits are exceeded.
    events: list[StreamEvent] = field(default_factory=list)
    # Event id -> absolute offset, used to resume after a last-seen event id.
    id_to_offset: dict[str, int] = field(default_factory=dict)
    # Absolute offset of events[0]; advances by one for every evicted event.
    start_offset: int = 0
    # Sum of the estimated sizes of the currently retained events.
    current_bytes: int = 0
    # Per-stream counter embedded in generated event ids.
    seq: int = 0
    # Lifecycle state; regular publishing is only accepted while ACTIVE.
    status: StreamStatus = StreamStatus.ACTIVE
    # The timestamps below are time.monotonic() values read by the cleanup sweep.
    created_at: float = field(default_factory=time.monotonic)
    # Time of the last regular publish (terminal events do not update it).
    last_publish_at: float | None = None
    # Time the stream entered a terminal state.
    ended_at: float | None = None
    # Number of subscribe() iterators currently attached.
    subscriber_count: int = 0
    # Time the most recent subscribe() attached.
    last_subscribe_at: float | None = None
    # Set by mark_awaiting_input(); selects the longer HITL no-publish timeout.
    awaiting_input: bool = False
    # When awaiting_input was set. NOTE(review): not read elsewhere in this module.
    awaiting_since: float | None = None
|
|
|
|
|
|
class MemoryStreamBridge(StreamBridge):
    """Per-run in-memory event log implementation.

    Each run id maps to a ``_RunStream`` holding a bounded window of events
    addressed by absolute offsets. Subscribers replay the retained window (or
    resume after a known event id) and then wait for new events. The
    background task started by :meth:`start` dead-letters streams that stop
    publishing and removes expired terminal streams.

    Locking: per-stream state is guarded by ``_RunStream.condition`` and the
    registry by ``_registry_lock``. Code that needs both takes the registry
    lock first. Items are never yielded to subscribers while a stream lock is
    held, so a slow or abandoned consumer cannot stall publishers.
    """

    # ``StreamEvent.event`` values that are the final entry of a stream's log.
    _TERMINAL_EVENT_NAMES: frozenset[str] = frozenset({"end", "cancel", "error", "dead_letter"})

    def __init__(
        self,
        *,
        max_events_per_stream: int = 256,
        max_bytes_per_stream: int = 10 * 1024 * 1024,
        max_active_streams: int = 1000,
        stream_eviction_policy: Literal["reject", "lru"] = "lru",
        terminal_retention_ttl: float = 300.0,
        active_no_publish_timeout: float = 600.0,
        orphan_timeout: float = 60.0,
        max_stream_age: float = 86400.0,
        hitl_extended_timeout: float = 7200.0,
        cleanup_interval: float = 30.0,
        queue_maxsize: int | None = None,
    ) -> None:
        """Configure retention limits and timeouts (all times in seconds).

        Args:
            max_events_per_stream: Maximum events retained per stream; the
                oldest are evicted first.
            max_bytes_per_stream: Budget for the estimated size of retained
                events per stream (see ``_estimate_size``).
            max_active_streams: Maximum number of streams in the registry.
            stream_eviction_policy: When the registry is full, ``"reject"``
                raises ``StreamCapacityExceededError``; ``"lru"`` evicts the
                terminal stream that ended earliest.
            terminal_retention_ttl: How long a terminal stream is kept after it ended.
            active_no_publish_timeout: Idle time after which an active stream
                is dead-lettered.
            orphan_timeout: How long a terminal stream without subscribers is kept.
            max_stream_age: Hard upper bound on a stream's lifetime.
            hitl_extended_timeout: Idle timeout used while awaiting human input.
            cleanup_interval: Delay between background cleanup sweeps.
            queue_maxsize: When not None, overrides ``max_events_per_stream``.
        """
        if queue_maxsize is not None:
            max_events_per_stream = queue_maxsize

        self._max_events = max_events_per_stream
        self._max_bytes = max_bytes_per_stream
        self._max_streams = max_active_streams
        self._eviction_policy = stream_eviction_policy
        self._terminal_ttl = terminal_retention_ttl
        self._active_timeout = active_no_publish_timeout
        self._orphan_timeout = orphan_timeout
        self._max_age = max_stream_age
        self._hitl_timeout = hitl_extended_timeout
        self._cleanup_interval = cleanup_interval
        self._streams: dict[str, _RunStream] = {}
        self._registry_lock = asyncio.Lock()
        self._closed = False
        self._cleanup_task: asyncio.Task[None] | None = None

    async def start(self) -> None:
        """Start the background cleanup task (a second call does not start another)."""
        if self._cleanup_task is None:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info(
            "MemoryStreamBridge started (max_events=%d, max_bytes=%d, max_streams=%d)",
            self._max_events,
            self._max_bytes,
            self._max_streams,
        )

    async def close(self) -> None:
        """Shut the bridge down.

        Stops the cleanup task, marks every stream ``CLOSED`` and wakes its
        subscribers (they then receive ``CANCELLED_SENTINEL``), and forgets all
        streams. Afterwards, operations that need a stream raise
        ``BridgeClosedError``.
        """
        async with self._registry_lock:
            self._closed = True
            if self._cleanup_task is not None:
                self._cleanup_task.cancel()
                try:
                    await self._cleanup_task
                except asyncio.CancelledError:
                    pass
                self._cleanup_task = None

            for stream in self._streams.values():
                async with stream.condition:
                    stream.status = StreamStatus.CLOSED
                    stream.condition.notify_all()

            self._streams.clear()
        logger.info("MemoryStreamBridge closed")

    async def _get_or_create_stream(self, run_id: str) -> _RunStream:
        """Return the stream for *run_id*, creating it if needed.

        Raises:
            BridgeClosedError: The bridge has been closed.
            StreamCapacityExceededError: The registry is full and no stream
                could be evicted (or the policy is ``"reject"``).
        """
        # Checked before the lock-free fast path too, so callers racing with
        # close() get a consistent BridgeClosedError.
        if self._closed:
            raise BridgeClosedError("Stream bridge is closed")

        stream = self._streams.get(run_id)
        if stream is not None:
            return stream

        async with self._registry_lock:
            if self._closed:
                raise BridgeClosedError("Stream bridge is closed")

            # Re-check: another task may have created it while we waited.
            stream = self._streams.get(run_id)
            if stream is not None:
                return stream

            if len(self._streams) >= self._max_streams:
                if self._eviction_policy == "reject":
                    raise StreamCapacityExceededError(
                        f"Max {self._max_streams} active streams reached"
                    )
                evicted = self._evict_oldest_terminal()
                if evicted is None:
                    raise StreamCapacityExceededError("All streams active, cannot evict")
                logger.info("Evicted stream %s to make room", evicted)

            stream = _RunStream()
            self._streams[run_id] = stream
            logger.debug("Created stream for run %s", run_id)
            return stream

    def _evict_oldest_terminal(self) -> str | None:
        """Remove the terminal stream with the earliest ``ended_at``.

        Caller must hold ``_registry_lock``. Returns the evicted run id, or
        None when no terminal stream exists.
        """
        oldest_run_id: str | None = None
        oldest_ended_at: float = float("inf")
        for run_id, stream in self._streams.items():
            if stream.status in TERMINAL_STATES and stream.ended_at is not None:
                if stream.ended_at < oldest_ended_at:
                    oldest_ended_at = stream.ended_at
                    oldest_run_id = run_id
        if oldest_run_id is not None:
            del self._streams[oldest_run_id]
            return oldest_run_id
        return None

    def _next_id(self, stream: _RunStream) -> str:
        """Return a new ``"<epoch-ms>-<seq>"`` id, unique within *stream*."""
        stream.seq += 1
        return f"{int(time.time() * 1000)}-{stream.seq}"

    def _estimate_size(self, event: StreamEvent) -> int:
        """Return a rough, deterministic size estimate for *event*.

        The same estimate is used when an event is added and when it is
        evicted, so ``current_bytes`` stays consistent. ``str`` payloads are
        measured in characters, dict/list payloads by their JSON length.
        """
        base = len(event.id) + len(event.event) + 100
        if event.data is None:
            return base
        if isinstance(event.data, str):
            return base + len(event.data)
        if isinstance(event.data, (dict, list)):
            try:
                return base + len(json.dumps(event.data, default=str))
            except (TypeError, ValueError):
                return base + 200
        return base + 50

    def _append_event(self, stream: _RunStream, event: str, data: Any) -> StreamEvent:
        """Append a new event to *stream* and update offset/size bookkeeping.

        Caller must hold ``stream.condition``. Retention limits are not
        applied and waiters are not notified; the caller does both as needed.
        """
        entry = StreamEvent(id=self._next_id(stream), event=event, data=data)
        stream.id_to_offset[entry.id] = stream.start_offset + len(stream.events)
        stream.events.append(entry)
        stream.current_bytes += self._estimate_size(entry)
        return entry

    def _evict_overflow(self, stream: _RunStream) -> None:
        """Drop the oldest events until *stream* is within its retention limits.

        Caller must hold ``stream.condition``. The newest event is always
        kept, even when it alone exceeds the byte budget. Otherwise an
        oversized event would be discarded by the same call that published
        it, before any subscriber could read it, and its returned id could
        never be resumed from.
        """
        events = stream.events
        while len(events) > 1 and (
            len(events) > self._max_events or stream.current_bytes > self._max_bytes
        ):
            evicted = events.pop(0)
            stream.id_to_offset.pop(evicted.id, None)
            stream.current_bytes -= self._estimate_size(evicted)
            stream.start_offset += 1

    async def publish(self, run_id: str, event: str, data: Any) -> str:
        """Append a regular event to *run_id*'s stream and return its id.

        Raises:
            StreamTerminatedError: The stream is no longer ACTIVE.
            BridgeClosedError: The bridge has been closed.
        """
        stream = await self._get_or_create_stream(run_id)
        async with stream.condition:
            if stream.status != StreamStatus.ACTIVE:
                raise StreamTerminatedError(
                    f"Cannot publish to {stream.status.value} stream"
                )
            entry = self._append_event(stream, event, data)
            stream.last_publish_at = time.monotonic()
            self._evict_overflow(stream)
            stream.condition.notify_all()
        return entry.id

    async def publish_end(self, run_id: str) -> str:
        """Terminate *run_id*'s stream normally; see :meth:`publish_terminal`."""
        return await self.publish_terminal(run_id, StreamStatus.ENDED)

    async def publish_terminal(
        self,
        run_id: str,
        kind: StreamStatus,
        data: Any = None,
    ) -> str:
        """Append the terminal event for *kind* and move the stream to that state.

        Idempotent: if the stream is already non-ACTIVE, the id of its
        existing terminal event is returned instead (``""`` if none is found).
        Retention limits are not applied here. Because regular publishing
        stops once a stream leaves ACTIVE, nothing evicts the terminal event
        afterwards.

        Raises:
            ValueError: *kind* is not a terminal state.
        """
        if kind not in TERMINAL_STATES:
            raise ValueError(f"Invalid terminal kind: {kind}")

        stream = await self._get_or_create_stream(run_id)
        async with stream.condition:
            if stream.status != StreamStatus.ACTIVE:
                for evt in reversed(stream.events):
                    if evt.event in self._TERMINAL_EVENT_NAMES:
                        return evt.id
                return ""

            event_name = {
                StreamStatus.ENDED: "end",
                StreamStatus.CANCELLED: "cancel",
                StreamStatus.ERRORED: "error",
            }[kind]
            entry = self._append_event(stream, event_name, data)
            stream.status = kind
            stream.ended_at = time.monotonic()
            stream.awaiting_input = False
            stream.condition.notify_all()
            logger.debug("Stream %s terminal: %s", run_id, kind.value)
        return entry.id

    async def cancel(self, run_id: str) -> None:
        """Terminate *run_id*'s stream with a ``cancel`` event."""
        await self.publish_terminal(run_id, StreamStatus.CANCELLED)

    async def subscribe(
        self,
        run_id: str,
        *,
        last_event_id: str | None = None,
        heartbeat_interval: float = 15.0,
    ) -> AsyncIterator[StreamEvent]:
        """Yield events for *run_id*, starting after *last_event_id* if given.

        Replays the retained events first, then waits for new ones. The
        iterator yields ``HEARTBEAT_SENTINEL`` when *heartbeat_interval*
        seconds pass without anything new. It yields ``END_SENTINEL`` if the
        stream is terminal with nothing left to read, and
        ``CANCELLED_SENTINEL`` if the stream or the bridge was closed.
        Iteration stops after a terminal event or a non-heartbeat sentinel.

        Every item is yielded with the stream lock released. Suspending the
        generator while holding the lock would block all publishers and other
        subscribers of this run until the consumer resumed or closed it.
        """
        stream = await self._get_or_create_stream(run_id)
        next_offset = self._resolve_resume_point(stream, last_event_id).next_offset

        async with stream.condition:
            stream.subscriber_count += 1
            stream.last_subscribe_at = time.monotonic()

        try:
            while True:
                item: StreamEvent | None = None
                is_final = False

                async with stream.condition:
                    if self._closed or stream.status == StreamStatus.CLOSED:
                        item = CANCELLED_SENTINEL
                        is_final = True
                    elif next_offset < stream.start_offset:
                        # Unread events were evicted; skip to the oldest retained one.
                        next_offset = stream.start_offset
                    elif next_offset < stream.start_offset + len(stream.events):
                        item = stream.events[next_offset - stream.start_offset]
                        next_offset += 1
                        is_final = item.event in self._TERMINAL_EVENT_NAMES
                    elif stream.status in TERMINAL_STATES:
                        item = END_SENTINEL
                        is_final = True
                    else:
                        try:
                            await asyncio.wait_for(
                                stream.condition.wait(),
                                timeout=heartbeat_interval,
                            )
                        except TimeoutError:
                            pass
                        # The lock is held again here. Heartbeat only if the
                        # wait produced nothing new to read (including the
                        # case where unread events were evicted meanwhile).
                        has_unread = next_offset < stream.start_offset + len(stream.events)
                        if not has_unread and stream.status not in TERMINAL_STATES:
                            item = HEARTBEAT_SENTINEL

                # Yield only after the stream lock has been released.
                if item is not None:
                    yield item
                if is_final:
                    return
        finally:
            async with stream.condition:
                stream.subscriber_count = max(0, stream.subscriber_count - 1)

    async def mark_awaiting_input(self, run_id: str) -> None:
        """Flag an ACTIVE stream as waiting for human input.

        While the flag is set, the cleanup sweep applies
        ``hitl_extended_timeout`` instead of ``active_no_publish_timeout``.
        The flag is cleared when the stream becomes terminal. No-op for
        unknown or non-ACTIVE runs.
        """
        stream = self._streams.get(run_id)
        if stream is None:
            return
        async with stream.condition:
            if stream.status == StreamStatus.ACTIVE:
                stream.awaiting_input = True
                stream.awaiting_since = time.monotonic()
                logger.debug("Stream %s marked as awaiting input", run_id)

    async def cleanup(self, run_id: str, *, delay: float = 0) -> None:
        """Remove *run_id*'s stream, optionally after *delay* seconds.

        Attached subscribers are woken and receive ``CANCELLED_SENTINEL``.
        """
        if delay > 0:
            await asyncio.sleep(delay)
        await self._do_cleanup(run_id, "manual")

    async def _do_cleanup(self, run_id: str, reason: str) -> None:
        """Unregister *run_id*'s stream, mark it CLOSED, and wake its subscribers."""
        async with self._registry_lock:
            stream = self._streams.pop(run_id, None)
            if stream is not None:
                async with stream.condition:
                    stream.status = StreamStatus.CLOSED
                    stream.condition.notify_all()
                logger.debug("Cleaned up stream %s (reason: %s)", run_id, reason)

    async def _mark_dead_letter(self, run_id: str, reason: str) -> None:
        """Terminate a still-ACTIVE stream with a ``dead_letter`` event (status ERRORED)."""
        stream = self._streams.get(run_id)
        if stream is None:
            return
        async with stream.condition:
            if stream.status != StreamStatus.ACTIVE:
                return
            self._append_event(
                stream,
                "dead_letter",
                {"reason": reason, "timestamp": time.time()},
            )
            stream.status = StreamStatus.ERRORED
            stream.ended_at = time.monotonic()
            # Consistent with publish_terminal: a terminal stream is not awaiting input.
            stream.awaiting_input = False
            stream.condition.notify_all()
            logger.warning("Stream %s marked as dead letter: %s", run_id, reason)

    async def _cleanup_loop(self) -> None:
        """Run a cleanup sweep every ``cleanup_interval`` seconds until closed.

        A failing sweep is logged and does not stop later sweeps. Previously a
        single unexpected error silently ended the background task for good.
        Cancellation (a BaseException) still stops the loop.
        """
        while not self._closed:
            try:
                await asyncio.sleep(self._cleanup_interval)
            except asyncio.CancelledError:
                break
            try:
                await self._run_cleanup_sweep()
            except Exception:
                logger.exception("MemoryStreamBridge cleanup sweep failed")

    async def _run_cleanup_sweep(self) -> None:
        """Dead-letter stalled streams and remove expired ones (one pass)."""
        now = time.monotonic()
        async with self._registry_lock:
            to_mark_dead, to_cleanup = self._collect_expired(now)

        for run_id, reason in to_mark_dead:
            await self._mark_dead_letter(run_id, reason)
        for run_id, reason in to_cleanup:
            await self._do_cleanup(run_id, reason)

    def _collect_expired(
        self, now: float
    ) -> tuple[list[tuple[str, str]], list[tuple[str, str]]]:
        """Classify registered streams for a cleanup sweep at time *now*.

        Caller must hold ``_registry_lock``. Returns
        ``(to_mark_dead, to_cleanup)``, each a list of ``(run_id, reason)``.
        """
        to_mark_dead: list[tuple[str, str]] = []
        to_cleanup: list[tuple[str, str]] = []

        for run_id, stream in self._streams.items():
            if now - stream.created_at > self._max_age:
                to_cleanup.append((run_id, "max_age_exceeded"))
                continue

            if stream.status == StreamStatus.ACTIVE:
                timeout = self._hitl_timeout if stream.awaiting_input else self._active_timeout
                last_activity = stream.last_publish_at or stream.created_at
                if now - last_activity > timeout:
                    to_mark_dead.append((run_id, "no_publish_timeout"))
                    continue

            if stream.status in TERMINAL_STATES and stream.ended_at is not None:
                if stream.subscriber_count > 0:
                    continue
                # The orphan clock starts no earlier than the end of the run. With
                # a plain `last_subscribe_at or ended_at`, a subscriber that attached
                # long before the run ended caused the stream to be dropped right
                # after it ended, leaving no window to reconnect and read the
                # terminal event.
                last_seen = max(stream.ended_at, stream.last_subscribe_at or stream.ended_at)
                if now - last_seen > self._orphan_timeout:
                    to_cleanup.append((run_id, "orphan"))
                    continue
                if now - stream.ended_at > self._terminal_ttl:
                    to_cleanup.append((run_id, "ttl_expired"))

        return to_mark_dead, to_cleanup

    def get_stats(self) -> dict[str, Any]:
        """Return aggregate counters over all registered streams."""
        active = sum(1 for s in self._streams.values() if s.status == StreamStatus.ACTIVE)
        terminal = sum(1 for s in self._streams.values() if s.status in TERMINAL_STATES)
        total_events = sum(len(s.events) for s in self._streams.values())
        total_bytes = sum(s.current_bytes for s in self._streams.values())
        total_subs = sum(s.subscriber_count for s in self._streams.values())
        return {
            "total_streams": len(self._streams),
            "active_streams": active,
            "terminal_streams": terminal,
            "total_events": total_events,
            "total_bytes": total_bytes,
            "total_subscribers": total_subs,
            "closed": self._closed,
        }

    def _resolve_resume_point(
        self,
        stream: _RunStream,
        last_event_id: str | None,
    ) -> ResumeResult:
        """Map a client's last-seen event id to the next absolute offset to read.

        Statuses:
            ``"fresh"``: no id given.
            ``"resumed"``: the id is still retained; reading continues right after it.
            ``"invalid"``: the id is not of the form ``<int>-<int>``.
            ``"evicted"``: the id's timestamp predates the oldest retained
                event; ``gap_count`` is the total number of events evicted so far.
            ``"unknown"``: well-formed but not found.
        Every status except ``"resumed"`` replays from the oldest retained event.
        """
        if last_event_id is None:
            return ResumeResult(next_offset=stream.start_offset, status="fresh")

        if last_event_id in stream.id_to_offset:
            return ResumeResult(
                next_offset=stream.id_to_offset[last_event_id] + 1,
                status="resumed",
            )

        parts = last_event_id.split("-")
        if len(parts) != 2:
            return ResumeResult(next_offset=stream.start_offset, status="invalid")
        try:
            event_ts = int(parts[0])
            _event_seq = int(parts[1])  # parsed only to validate the id format
        except ValueError:
            return ResumeResult(next_offset=stream.start_offset, status="invalid")

        if stream.events:
            try:
                oldest_parts = stream.events[0].id.split("-")
                oldest_ts = int(oldest_parts[0])
                if event_ts < oldest_ts:
                    return ResumeResult(
                        next_offset=stream.start_offset,
                        status="evicted",
                        gap_count=stream.start_offset,
                    )
            except (ValueError, IndexError):
                pass

        return ResumeResult(next_offset=stream.start_offset, status="unknown")
|
|
|
|
|
|
# Public API of this module: only the bridge class is exported.
__all__: list[str] = ["MemoryStreamBridge"]
|