refactor(runtime): add run DDD boundary skeleton

This commit is contained in:
rayhpeng 2026-06-01 09:22:32 +08:00
parent 9f3be2a9fa
commit 30bb2d5149
24 changed files with 1075 additions and 20 deletions

View File

@ -1,16 +1,39 @@
"""Run lifecycle management for LangGraph Platform API compatibility."""
from .domain import (
AssistantId,
CancelAction,
DisconnectMode,
EventSeq,
InvalidRunTransition,
MultitaskStrategy,
Run,
RunId,
RunScope,
RunStatus,
ThreadId,
UserId,
)
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
from .schemas import DisconnectMode, RunStatus
from .worker import RunContext, run_agent
__all__ = [
"AssistantId",
"CancelAction",
"ConflictError",
"DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskStrategy",
"Run",
"RunContext",
"RunId",
"RunManager",
"RunRecord",
"RunScope",
"RunStatus",
"ThreadId",
"UnsupportedStrategyError",
"UserId",
"run_agent",
]

View File

@ -0,0 +1,20 @@
"""Application-layer DTOs and services for run runtime use cases."""
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle, StoredRunEvent
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
from .services import RunsApplicationService
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"GetRunQuery",
"JoinRunStreamCommand",
"ListRunMessagesQuery",
"ListRunsQuery",
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"RunsApplicationService",
"StoredRunEvent",
]

View File

@ -0,0 +1,46 @@
"""Application command DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Literal
from ..domain import AssistantId, CancelAction, DisconnectMode, MultitaskStrategy, RunId, RunScope, ThreadId
@dataclass(frozen=True)
class CreateRunCommand:
thread_id: ThreadId
assistant_id: AssistantId | None = None
input: dict[str, Any] | None = None
command: dict[str, Any] | None = None
metadata: dict[str, Any] = field(default_factory=dict)
config: dict[str, Any] = field(default_factory=dict)
context: dict[str, Any] = field(default_factory=dict)
scope: RunScope = RunScope.stateful
on_disconnect: DisconnectMode = DisconnectMode.cancel
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
stream_mode: list[str] | str | None = None
stream_subgraphs: bool = False
interrupt_before: list[str] | Literal["*"] | None = None
interrupt_after: list[str] | Literal["*"] | None = None
@dataclass(frozen=True)
class CancelRunCommand:
run_id: RunId
action: CancelAction = CancelAction.interrupt
wait: bool = False
@dataclass(frozen=True)
class JoinRunStreamCommand:
run_id: RunId
last_event_id: str | None = None
__all__ = [
"CancelRunCommand",
"CreateRunCommand",
"JoinRunStreamCommand",
]

View File

@ -0,0 +1,76 @@
"""Application output DTOs for run use cases."""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import Any
from ..domain import AssistantId, EventSeq, Run, RunId, RunStatus, ThreadId
@dataclass(frozen=True)
class RunSnapshot:
run_id: RunId
thread_id: ThreadId
assistant_id: AssistantId | None = None
status: RunStatus = RunStatus.pending
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
updated_at: str = ""
error: str | None = None
model_name: str | None = None
@classmethod
def from_run(cls, run: Run) -> RunSnapshot:
return cls(
run_id=run.run_id,
thread_id=run.thread_id,
assistant_id=run.assistant_id,
status=run.status,
metadata=dict(run.metadata),
kwargs=dict(run.kwargs),
created_at=run.created_at,
updated_at=run.updated_at,
error=run.error,
model_name=run.model_name,
)
@dataclass(frozen=True)
class RunMessageView:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class StoredRunEvent:
thread_id: ThreadId
run_id: RunId
seq: EventSeq
event_type: str
category: str
content: str | dict[str, Any] = ""
metadata: dict[str, Any] = field(default_factory=dict)
created_at: str = ""
@dataclass(frozen=True)
class RunStreamHandle:
run_id: RunId
thread_id: ThreadId
events: AsyncIterator[Any]
__all__ = [
"RunMessageView",
"RunSnapshot",
"RunStreamHandle",
"StoredRunEvent",
]

View File

@ -0,0 +1,37 @@
"""Application query DTOs for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..domain import RunId, ThreadId, UserId
@dataclass(frozen=True)
class GetRunQuery:
run_id: RunId
thread_id: ThreadId | None = None
user_id: UserId | None = None
@dataclass(frozen=True)
class ListRunsQuery:
thread_id: ThreadId
user_id: UserId | None = None
limit: int = 100
@dataclass(frozen=True)
class ListRunMessagesQuery:
thread_id: ThreadId
run_id: RunId
limit: int = 50
before_seq: int | None = None
after_seq: int | None = None
__all__ = [
"GetRunQuery",
"ListRunMessagesQuery",
"ListRunsQuery",
]

View File

@ -0,0 +1,74 @@
"""Application service skeleton for run use cases."""
from __future__ import annotations
from dataclasses import dataclass
from ..execution import RunExecutionScheduler, RunSupervisor
from ..repositories import RunEventLog, RunRepository
from ..streams import RunStreamBroker
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
from .dto import RunMessageView, RunSnapshot, RunStreamHandle
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
@dataclass
class RunsApplicationService:
"""Use-case orchestration boundary for run runtime operations.
PR1 only introduces the boundary and dependency shape. Existing Gateway
handlers continue to call the legacy service functions until later PRs move
behavior into this class.
"""
run_repository: RunRepository
run_event_log: RunEventLog
stream_broker: RunStreamBroker
scheduler: RunExecutionScheduler
supervisor: RunSupervisor
async def create_background(self, command: CreateRunCommand) -> RunSnapshot:
# PR1 defines the application boundary; later PRs move Gateway runtime
# behavior behind this method.
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_stream(self, command: CreateRunCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def create_and_wait(self, command: CreateRunCommand) -> RunSnapshot:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def join_stream(self, command: JoinRunStreamCommand) -> RunStreamHandle:
raise NotImplementedError("RunsApplicationService is not wired in PR1")
async def cancel(self, command: CancelRunCommand) -> bool:
return await self.supervisor.cancel(command.run_id, action=command.action)
async def get_run(self, query: GetRunQuery) -> RunSnapshot | None:
run = await self.run_repository.get(query.run_id, user_id=query.user_id)
if run is None:
return None
if query.thread_id is not None and run.thread_id != query.thread_id:
return None
return RunSnapshot.from_run(run)
async def list_runs(self, query: ListRunsQuery) -> list[RunSnapshot]:
return await self.run_repository.list_by_thread(
query.thread_id,
user_id=query.user_id,
limit=query.limit,
)
async def list_run_messages(self, query: ListRunMessagesQuery) -> list[RunMessageView]:
return await self.run_event_log.list_messages_by_run(
query.thread_id,
query.run_id,
limit=query.limit,
before_seq=query.before_seq,
after_seq=query.after_seq,
)
__all__ = [
"RunsApplicationService",
]

View File

@ -0,0 +1,33 @@
"""Run runtime domain model."""
from .errors import InvalidRunTransition, RunDomainError
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, UserId
from .model import Run
from .policies import CancelPolicy, MultitaskDecision, MultitaskPolicy
from .value_objects import CancelAction, DisconnectMode, EventSeq, MultitaskStrategy, RunScope, RunStatus
__all__ = [
"AssistantId",
"CancelAction",
"CancelPolicy",
"DisconnectMode",
"EventSeq",
"InvalidRunTransition",
"MultitaskDecision",
"MultitaskPolicy",
"MultitaskStrategy",
"Run",
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunDomainError",
"RunEvent",
"RunFailed",
"RunId",
"RunScope",
"RunStarted",
"RunStatus",
"ThreadId",
"UserId",
]

View File

@ -0,0 +1,24 @@
"""Domain-level errors for run lifecycle operations."""
from __future__ import annotations
from .value_objects import RunStatus
class RunDomainError(Exception):
"""Base class for run runtime domain errors."""
class InvalidRunTransition(RunDomainError):
"""Raised when a run status transition violates lifecycle rules."""
def __init__(self, current: RunStatus, target: RunStatus) -> None:
super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
self.current = current
self.target = target
__all__ = [
"InvalidRunTransition",
"RunDomainError",
]

View File

@ -0,0 +1,64 @@
"""Domain events emitted by the run aggregate."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .identifiers import AssistantId, RunId, ThreadId
from .value_objects import CancelAction, RunStatus
@dataclass(frozen=True)
class RunCreated:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
assistant_id: AssistantId | None = None
metadata: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class RunStarted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunCompleted:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
@dataclass(frozen=True)
class RunFailed:
run_id: RunId
thread_id: ThreadId
status: RunStatus
occurred_at: str = field(default_factory=now_iso)
error: str | None = None
@dataclass(frozen=True)
class RunCancelled:
run_id: RunId
thread_id: ThreadId
occurred_at: str = field(default_factory=now_iso)
action: CancelAction = CancelAction.interrupt
RunEvent = RunCreated | RunStarted | RunCompleted | RunFailed | RunCancelled
__all__ = [
"RunCancelled",
"RunCompleted",
"RunCreated",
"RunEvent",
"RunFailed",
"RunStarted",
]

View File

@ -0,0 +1,27 @@
"""Lightweight identifiers for the run runtime domain."""
from __future__ import annotations
from typing import NewType
RunId = NewType("RunId", str)
ThreadId = NewType("ThreadId", str)
AssistantId = NewType("AssistantId", str)
UserId = NewType("UserId", str)
def require_non_empty(value: str, *, field_name: str) -> str:
"""Return a stripped identifier value, rejecting empty identifiers."""
normalized = value.strip()
if not normalized:
raise ValueError(f"{field_name} must not be empty")
return normalized
__all__ = [
"AssistantId",
"RunId",
"ThreadId",
"UserId",
"require_non_empty",
]

View File

@ -0,0 +1,193 @@
"""Run aggregate root and lifecycle invariants."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from deerflow.utils.time import now_iso
from .errors import InvalidRunTransition
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
# Keep lifecycle transitions explicit so later application code cannot invent
# ad hoc status moves outside the aggregate.
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
RunStatus.pending: frozenset(
{
RunStatus.running,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.running: frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
),
RunStatus.success: frozenset(),
RunStatus.error: frozenset(),
RunStatus.timeout: frozenset(),
RunStatus.interrupted: frozenset(),
}
@dataclass
class Run:
"""Run aggregate root.
The aggregate owns lifecycle invariants only. Infrastructure concerns such
as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
this model.
"""
run_id: RunId
thread_id: ThreadId
status: RunStatus
assistant_id: AssistantId | None = None
scope: RunScope = RunScope.stateful
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
metadata: dict[str, Any] = field(default_factory=dict)
kwargs: dict[str, Any] = field(default_factory=dict)
created_at: str = field(default_factory=now_iso)
updated_at: str = field(default_factory=now_iso)
error: str | None = None
model_name: str | None = None
_pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)
def __post_init__(self) -> None:
self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
if self.assistant_id is not None:
self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))
@classmethod
def create(
cls,
*,
run_id: RunId,
thread_id: ThreadId,
assistant_id: AssistantId | None = None,
scope: RunScope = RunScope.stateful,
multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
metadata: dict[str, Any] | None = None,
kwargs: dict[str, Any] | None = None,
model_name: str | None = None,
created_at: str | None = None,
) -> Run:
timestamp = created_at or now_iso()
run = cls(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
status=RunStatus.pending,
scope=scope,
multitask_strategy=multitask_strategy,
metadata=metadata or {},
kwargs=kwargs or {},
created_at=timestamp,
updated_at=timestamp,
model_name=model_name,
)
run._record_event(
RunCreated(
run_id=run.run_id,
thread_id=run.thread_id,
occurred_at=timestamp,
assistant_id=run.assistant_id,
metadata=dict(run.metadata),
)
)
return run
@property
def is_terminal(self) -> bool:
return not _ALLOWED_TRANSITIONS[self.status]
def pull_events(self) -> tuple[RunEvent, ...]:
# Domain events are drained by the application layer after the aggregate
# has accepted a state change.
events = tuple(self._pending_events)
self._pending_events.clear()
return events
def mark_started(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.running, at=at)
def mark_completed(self, *, at: str | None = None) -> None:
self._transition_to(RunStatus.success, at=at)
def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.error, error=error, at=at)
def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
self._transition_to(RunStatus.timeout, error=error, at=at)
def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
self._transition_to(RunStatus.interrupted, action=action, at=at)
def _transition_to(
self,
target: RunStatus,
*,
error: str | None = None,
action: CancelAction = CancelAction.interrupt,
at: str | None = None,
) -> None:
if target == self.status:
return
if target not in _ALLOWED_TRANSITIONS[self.status]:
raise InvalidRunTransition(self.status, target)
timestamp = at or now_iso()
self.status = target
self.updated_at = timestamp
if error is not None:
self.error = error
self._record_event(self._event_for_transition(target, timestamp, error=error, action=action))
def _event_for_transition(
self,
target: RunStatus,
occurred_at: str,
*,
error: str | None,
action: CancelAction,
) -> RunEvent:
# Keep event construction next to the transition rules so a new status
# cannot be added without an explicit durable event shape.
if target == RunStatus.running:
return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target == RunStatus.success:
return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
if target in (RunStatus.error, RunStatus.timeout):
return RunFailed(
run_id=self.run_id,
thread_id=self.thread_id,
status=target,
occurred_at=occurred_at,
error=error,
)
if target == RunStatus.interrupted:
return RunCancelled(
run_id=self.run_id,
thread_id=self.thread_id,
occurred_at=occurred_at,
action=action,
)
raise InvalidRunTransition(self.status, target)
def _record_event(self, event: RunEvent) -> None:
self._pending_events.append(event)
__all__ = [
"Run",
"RunStatus",
]

View File

@ -0,0 +1,50 @@
"""Domain policies for run concurrency and cancellation."""
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from enum import StrEnum
from .model import Run
from .value_objects import CancelAction, MultitaskStrategy, RunStatus
class MultitaskDecision(StrEnum):
"""Application-level decision produced by a multitask policy."""
allow = "allow"
reject = "reject"
cancel_existing = "cancel_existing"
enqueue = "enqueue"
@dataclass(frozen=True)
class MultitaskPolicy:
strategy: MultitaskStrategy = MultitaskStrategy.reject
def decide(self, active_runs: Sequence[Run]) -> MultitaskDecision:
inflight = [run for run in active_runs if run.status in (RunStatus.pending, RunStatus.running)]
if not inflight:
return MultitaskDecision.allow
if self.strategy == MultitaskStrategy.reject:
return MultitaskDecision.reject
if self.strategy in (MultitaskStrategy.interrupt, MultitaskStrategy.rollback):
return MultitaskDecision.cancel_existing
return MultitaskDecision.enqueue
@dataclass(frozen=True)
class CancelPolicy:
action: CancelAction = CancelAction.interrupt
@property
def rolls_back_checkpoint(self) -> bool:
return self.action == CancelAction.rollback
__all__ = [
"CancelPolicy",
"MultitaskDecision",
"MultitaskPolicy",
]

View File

@ -0,0 +1,88 @@
"""Domain value objects for run lifecycle semantics."""
from __future__ import annotations
from dataclasses import dataclass
from enum import StrEnum
class RunStatus(StrEnum):
"""Lifecycle status of a single run."""
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"
class RunScope(StrEnum):
"""Conversation scope for a run."""
stateful = "stateful"
stateless = "stateless"
temporary_thread = "temporary_thread"
class MultitaskStrategy(StrEnum):
"""Concurrency strategy for a new run on a thread."""
reject = "reject"
interrupt = "interrupt"
rollback = "rollback"
enqueue = "enqueue"
class CancelAction(StrEnum):
"""Cancellation action requested by an API or supervisor."""
interrupt = "interrupt"
rollback = "rollback"
TERMINAL_RUN_STATUSES: frozenset[RunStatus] = frozenset(
{
RunStatus.success,
RunStatus.error,
RunStatus.timeout,
RunStatus.interrupted,
}
)
def is_terminal_status(status: RunStatus) -> bool:
return status in TERMINAL_RUN_STATUSES
@dataclass(frozen=True, order=True)
class EventSeq:
"""Thread-local event sequence number."""
value: int
def __post_init__(self) -> None:
if self.value < 0:
raise ValueError("EventSeq must be non-negative")
def next(self) -> EventSeq:
return EventSeq(self.value + 1)
__all__ = [
"CancelAction",
"DisconnectMode",
"EventSeq",
"MultitaskStrategy",
"RunScope",
"RunStatus",
"TERMINAL_RUN_STATUSES",
"is_terminal_status",
]

View File

@ -0,0 +1,12 @@
"""Execution contracts for run lifecycle orchestration."""
from .executor import RunExecutor
from .scheduler import RunExecutionHandle, RunExecutionScheduler
from .supervisor import RunSupervisor
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
"RunExecutor",
"RunSupervisor",
]

View File

@ -0,0 +1,18 @@
"""Run executor contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import Run
class RunExecutor(Protocol):
"""Executes one run against the underlying agent or graph runtime."""
async def execute(self, run: Run) -> None: ...
__all__ = [
"RunExecutor",
]

View File

@ -0,0 +1,25 @@
"""Run execution scheduler contract."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Protocol
from ..domain import RunId
@dataclass(frozen=True)
class RunExecutionHandle:
run_id: RunId
class RunExecutionScheduler(Protocol):
"""Starts background execution for an accepted run."""
async def start(self, run_id: RunId) -> RunExecutionHandle: ...
__all__ = [
"RunExecutionHandle",
"RunExecutionScheduler",
]

View File

@ -0,0 +1,18 @@
"""Run execution supervision contract."""
from __future__ import annotations
from typing import Protocol
from ..domain import CancelAction, RunId
class RunSupervisor(Protocol):
"""Controls lifecycle operations for already scheduled runs."""
async def cancel(self, run_id: RunId, *, action: CancelAction = CancelAction.interrupt) -> bool: ...
__all__ = [
"RunSupervisor",
]

View File

@ -0,0 +1,9 @@
"""Repository contracts for the run runtime application layer."""
from .run_event_log import RunEventLog
from .run_repository import RunRepository
__all__ = [
"RunEventLog",
"RunRepository",
]

View File

@ -0,0 +1,39 @@
"""Durable run event log contract."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from ..domain import RunEvent, RunId, ThreadId
if TYPE_CHECKING:
from ..application.dto import RunMessageView, StoredRunEvent
class RunEventLog(Protocol):
"""Persistence boundary for run messages and execution trace events."""
async def append(self, events: list[RunEvent]) -> list[StoredRunEvent]: ...
async def list_messages_by_run(
self,
thread_id: ThreadId,
run_id: RunId,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
) -> list[RunMessageView]: ...
async def list_events_by_run(
self,
thread_id: ThreadId,
run_id: RunId,
*,
limit: int = 500,
) -> list[StoredRunEvent]: ...
__all__ = [
"RunEventLog",
]

View File

@ -0,0 +1,33 @@
"""Run state repository contract."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from ..domain import Run, RunId, ThreadId, UserId
if TYPE_CHECKING:
from ..application.dto import RunSnapshot
class RunRepository(Protocol):
"""Persistence boundary for run state snapshots."""
async def save(self, run: Run) -> None: ...
async def get(self, run_id: RunId, *, user_id: UserId | None = None) -> Run | None: ...
async def list_by_thread(
self,
thread_id: ThreadId,
*,
user_id: UserId | None = None,
limit: int = 100,
) -> list[RunSnapshot]: ...
async def delete(self, run_id: RunId) -> bool: ...
__all__ = [
"RunRepository",
]

View File

@ -1,21 +1,10 @@
"""Run status and disconnect mode enums."""
"""Compatibility exports for run status and disconnect mode enums."""
from enum import StrEnum
# Existing callers import these enums from ``runs.schemas``. Re-export the
# domain definitions until all imports move to ``runs.domain``.
from .domain import DisconnectMode, RunStatus
class RunStatus(StrEnum):
"""Lifecycle status of a single run."""
pending = "pending"
running = "running"
success = "success"
error = "error"
timeout = "timeout"
interrupted = "interrupted"
class DisconnectMode(StrEnum):
"""Behaviour when the SSE consumer disconnects."""
cancel = "cancel"
continue_ = "continue"
__all__ = [
"DisconnectMode",
"RunStatus",
]

View File

@ -0,0 +1,8 @@
"""Realtime stream contracts for run application use cases."""
from .run_stream_broker import RunStreamBroker, RunStreamEvent
__all__ = [
"RunStreamBroker",
"RunStreamEvent",
]

View File

@ -0,0 +1,40 @@
"""Realtime run stream broker contract."""
from __future__ import annotations
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Protocol
from ..domain import RunId
@dataclass(frozen=True)
class RunStreamEvent:
id: str
event: str
data: Any
class RunStreamBroker(Protocol):
"""Realtime publish/subscribe boundary for run streams."""
async def publish(self, run_id: RunId, event: str, data: Any) -> None: ...
async def publish_terminal(self, run_id: RunId, *, event: str = "end", data: Any = None) -> None: ...
def subscribe(
self,
run_id: RunId,
*,
last_event_id: str | None = None,
heartbeat_interval: float = 15.0,
) -> AsyncIterator[RunStreamEvent]: ...
async def cleanup(self, run_id: RunId, *, delay: float = 0) -> None: ...
__all__ = [
"RunStreamBroker",
"RunStreamEvent",
]

View File

@ -0,0 +1,109 @@
"""Tests for the DDD run domain skeleton."""
import pytest
from deerflow.runtime.runs import DisconnectMode, RunStatus
from deerflow.runtime.runs.domain import (
AssistantId,
CancelAction,
EventSeq,
InvalidRunTransition,
MultitaskStrategy,
Run,
RunCancelled,
RunCompleted,
RunCreated,
RunFailed,
RunId,
RunScope,
RunStarted,
ThreadId,
)
from deerflow.runtime.runs.schemas import DisconnectMode as CompatDisconnectMode
from deerflow.runtime.runs.schemas import RunStatus as CompatRunStatus
def test_compat_schema_exports_use_domain_enums() -> None:
assert CompatRunStatus is RunStatus
assert CompatDisconnectMode is DisconnectMode
def test_create_run_records_pending_state_and_created_event() -> None:
run = Run.create(
run_id=RunId("run-1"),
thread_id=ThreadId("thread-1"),
assistant_id=AssistantId("lead_agent"),
scope=RunScope.stateful,
multitask_strategy=MultitaskStrategy.reject,
metadata={"source": "test"},
kwargs={"input": {"messages": []}},
created_at="2026-01-01T00:00:00+00:00",
)
assert run.status == RunStatus.pending
assert run.run_id == "run-1"
assert run.thread_id == "thread-1"
assert run.assistant_id == "lead_agent"
assert run.created_at == "2026-01-01T00:00:00+00:00"
assert run.updated_at == "2026-01-01T00:00:00+00:00"
events = run.pull_events()
assert len(events) == 1
assert isinstance(events[0], RunCreated)
assert events[0].metadata == {"source": "test"}
assert run.pull_events() == ()
def test_run_allows_pending_running_success_transition() -> None:
run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
run.pull_events()
run.mark_started(at="2026-01-01T00:00:01+00:00")
run.mark_completed(at="2026-01-01T00:00:02+00:00")
assert run.status == RunStatus.success
assert run.updated_at == "2026-01-01T00:00:02+00:00"
events = run.pull_events()
assert [type(event) for event in events] == [RunStarted, RunCompleted]
def test_run_records_failed_and_cancelled_domain_events() -> None:
failed = Run.create(run_id=RunId("run-failed"), thread_id=ThreadId("thread-1"))
failed.pull_events()
failed.mark_started()
failed.mark_failed("boom", at="2026-01-01T00:00:03+00:00")
failed_events = failed.pull_events()
assert failed.status == RunStatus.error
assert isinstance(failed_events[-1], RunFailed)
assert failed_events[-1].status == RunStatus.error
assert failed_events[-1].error == "boom"
cancelled = Run.create(run_id=RunId("run-cancelled"), thread_id=ThreadId("thread-1"))
cancelled.pull_events()
cancelled.mark_cancelled(action=CancelAction.rollback)
cancelled_events = cancelled.pull_events()
assert cancelled.status == RunStatus.interrupted
assert isinstance(cancelled_events[-1], RunCancelled)
assert cancelled_events[-1].action == CancelAction.rollback
def test_terminal_run_cannot_transition_again() -> None:
run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
run.mark_started()
run.mark_completed()
with pytest.raises(InvalidRunTransition) as exc:
run.mark_failed("too late")
assert exc.value.current == RunStatus.success
assert exc.value.target == RunStatus.error
def test_domain_value_objects_validate_minimal_invariants() -> None:
assert EventSeq(1).next() == EventSeq(2)
with pytest.raises(ValueError, match="EventSeq"):
EventSeq(-1)
with pytest.raises(ValueError, match="run_id"):
Run.create(run_id=RunId(" "), thread_id=ThreadId("thread-1"))