mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-09 17:12:01 +00:00
refactor(runtime): add run DDD boundary skeleton
This commit is contained in:
parent
9f3be2a9fa
commit
30bb2d5149
@ -1,16 +1,39 @@
|
||||
"""Run lifecycle management for LangGraph Platform API compatibility."""
|
||||
|
||||
from .domain import (
|
||||
AssistantId,
|
||||
CancelAction,
|
||||
DisconnectMode,
|
||||
EventSeq,
|
||||
InvalidRunTransition,
|
||||
MultitaskStrategy,
|
||||
Run,
|
||||
RunId,
|
||||
RunScope,
|
||||
RunStatus,
|
||||
ThreadId,
|
||||
UserId,
|
||||
)
|
||||
from .manager import ConflictError, RunManager, RunRecord, UnsupportedStrategyError
|
||||
from .schemas import DisconnectMode, RunStatus
|
||||
from .worker import RunContext, run_agent
|
||||
|
||||
__all__ = [
|
||||
"AssistantId",
|
||||
"CancelAction",
|
||||
"ConflictError",
|
||||
"DisconnectMode",
|
||||
"EventSeq",
|
||||
"InvalidRunTransition",
|
||||
"MultitaskStrategy",
|
||||
"Run",
|
||||
"RunContext",
|
||||
"RunId",
|
||||
"RunManager",
|
||||
"RunRecord",
|
||||
"RunScope",
|
||||
"RunStatus",
|
||||
"ThreadId",
|
||||
"UnsupportedStrategyError",
|
||||
"UserId",
|
||||
"run_agent",
|
||||
]
|
||||
|
||||
@ -0,0 +1,20 @@
|
||||
"""Application-layer DTOs and services for run runtime use cases."""
|
||||
|
||||
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
|
||||
from .dto import RunMessageView, RunSnapshot, RunStreamHandle, StoredRunEvent
|
||||
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
|
||||
from .services import RunsApplicationService
|
||||
|
||||
__all__ = [
|
||||
"CancelRunCommand",
|
||||
"CreateRunCommand",
|
||||
"GetRunQuery",
|
||||
"JoinRunStreamCommand",
|
||||
"ListRunMessagesQuery",
|
||||
"ListRunsQuery",
|
||||
"RunMessageView",
|
||||
"RunSnapshot",
|
||||
"RunStreamHandle",
|
||||
"RunsApplicationService",
|
||||
"StoredRunEvent",
|
||||
]
|
||||
@ -0,0 +1,46 @@
|
||||
"""Application command DTOs for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
from ..domain import AssistantId, CancelAction, DisconnectMode, MultitaskStrategy, RunId, RunScope, ThreadId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CreateRunCommand:
    """Immutable request to create a run on a thread.

    ``thread_id`` is the only required field; everything else falls back to
    the defaults declared below.
    """

    thread_id: ThreadId
    assistant_id: AssistantId | None = None

    # Payload dictionaries are carried as-is; this DTO does not validate them.
    input: dict[str, Any] | None = None
    command: dict[str, Any] | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    config: dict[str, Any] = field(default_factory=dict)
    context: dict[str, Any] = field(default_factory=dict)

    # Lifecycle and concurrency options.
    scope: RunScope = RunScope.stateful
    on_disconnect: DisconnectMode = DisconnectMode.cancel
    multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject

    # Streaming and interrupt options.
    stream_mode: list[str] | str | None = None
    stream_subgraphs: bool = False
    interrupt_before: list[str] | Literal["*"] | None = None
    interrupt_after: list[str] | Literal["*"] | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CancelRunCommand:
    """Immutable request to cancel a run.

    ``action`` defaults to ``CancelAction.interrupt`` and ``wait`` defaults to
    ``False``.
    """

    run_id: RunId
    action: CancelAction = CancelAction.interrupt
    wait: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class JoinRunStreamCommand:
    """Immutable request to attach to the stream of an existing run."""

    run_id: RunId
    # Optional id of the last event seen by the caller; ``None`` when not supplied.
    last_event_id: str | None = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CancelRunCommand",
|
||||
"CreateRunCommand",
|
||||
"JoinRunStreamCommand",
|
||||
]
|
||||
@ -0,0 +1,76 @@
|
||||
"""Application output DTOs for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from ..domain import AssistantId, EventSeq, Run, RunId, RunStatus, ThreadId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunSnapshot:
    """Immutable read-only view of a run's state."""

    run_id: RunId
    thread_id: ThreadId
    assistant_id: AssistantId | None = None
    status: RunStatus = RunStatus.pending
    metadata: dict[str, Any] = field(default_factory=dict)
    kwargs: dict[str, Any] = field(default_factory=dict)
    created_at: str = ""
    updated_at: str = ""
    error: str | None = None
    model_name: str | None = None

    @classmethod
    def from_run(cls, run: Run) -> RunSnapshot:
        """Build a snapshot from a ``Run`` aggregate.

        ``metadata`` and ``kwargs`` are shallow-copied so the snapshot does
        not share those top-level dicts with the aggregate.
        """
        metadata_copy = dict(run.metadata)
        kwargs_copy = dict(run.kwargs)
        return cls(
            run_id=run.run_id,
            thread_id=run.thread_id,
            assistant_id=run.assistant_id,
            status=run.status,
            metadata=metadata_copy,
            kwargs=kwargs_copy,
            created_at=run.created_at,
            updated_at=run.updated_at,
            error=run.error,
            model_name=run.model_name,
        )
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunMessageView:
    """Immutable view of one message-like event belonging to a run."""

    # Identity of the event.
    thread_id: ThreadId
    run_id: RunId
    seq: EventSeq
    event_type: str

    # Payload; ``content`` may be plain text or a dictionary.
    content: str | dict[str, Any] = ""
    metadata: dict[str, Any] = field(default_factory=dict)
    created_at: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class StoredRunEvent:
    """Immutable record of a run event as returned by the event log."""

    # Identity and classification of the event.
    thread_id: ThreadId
    run_id: RunId
    seq: EventSeq
    event_type: str
    category: str

    # Payload; ``content`` may be plain text or a dictionary.
    content: str | dict[str, Any] = ""
    metadata: dict[str, Any] = field(default_factory=dict)
    created_at: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunStreamHandle:
    """Immutable pairing of a run's identifiers with its event iterator."""

    run_id: RunId
    thread_id: ThreadId
    # Asynchronous iterator of stream items; the item type is left open (``Any``).
    events: AsyncIterator[Any]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunMessageView",
|
||||
"RunSnapshot",
|
||||
"RunStreamHandle",
|
||||
"StoredRunEvent",
|
||||
]
|
||||
@ -0,0 +1,37 @@
|
||||
"""Application query DTOs for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..domain import RunId, ThreadId, UserId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GetRunQuery:
    """Immutable lookup of a single run by id.

    ``thread_id`` and ``user_id`` are optional narrowing criteria; both
    default to ``None``.
    """

    run_id: RunId
    thread_id: ThreadId | None = None
    user_id: UserId | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ListRunsQuery:
    """Immutable request to list the runs of one thread."""

    thread_id: ThreadId
    user_id: UserId | None = None
    # Maximum number of runs to return.
    limit: int = 100
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ListRunMessagesQuery:
    """Immutable request to page through the messages of one run."""

    thread_id: ThreadId
    run_id: RunId
    # Maximum number of messages to return.
    limit: int = 50
    # Optional sequence-number bounds; ``None`` leaves that side unbounded.
    before_seq: int | None = None
    after_seq: int | None = None
|
||||
|
||||
|
||||
__all__ = [
|
||||
"GetRunQuery",
|
||||
"ListRunMessagesQuery",
|
||||
"ListRunsQuery",
|
||||
]
|
||||
@ -0,0 +1,74 @@
|
||||
"""Application service skeleton for run use cases."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..execution import RunExecutionScheduler, RunSupervisor
|
||||
from ..repositories import RunEventLog, RunRepository
|
||||
from ..streams import RunStreamBroker
|
||||
from .commands import CancelRunCommand, CreateRunCommand, JoinRunStreamCommand
|
||||
from .dto import RunMessageView, RunSnapshot, RunStreamHandle
|
||||
from .queries import GetRunQuery, ListRunMessagesQuery, ListRunsQuery
|
||||
|
||||
|
||||
@dataclass
class RunsApplicationService:
    """Use-case orchestration boundary for run runtime operations.

    PR1 only introduces the boundary and dependency shape. Existing Gateway
    handlers continue to call the legacy service functions until later PRs move
    behavior into this class.

    The four ``create_*`` / ``join_stream`` entry points are placeholders that
    raise ``NotImplementedError``; the cancel and read methods delegate
    directly to the injected collaborators.
    """

    run_repository: RunRepository
    run_event_log: RunEventLog
    stream_broker: RunStreamBroker
    scheduler: RunExecutionScheduler
    supervisor: RunSupervisor

    # -- placeholders: later PRs move Gateway runtime behavior behind these --

    async def create_background(self, command: CreateRunCommand) -> RunSnapshot:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("RunsApplicationService is not wired in PR1")

    async def create_and_stream(self, command: CreateRunCommand) -> RunStreamHandle:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("RunsApplicationService is not wired in PR1")

    async def create_and_wait(self, command: CreateRunCommand) -> RunSnapshot:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("RunsApplicationService is not wired in PR1")

    async def join_stream(self, command: JoinRunStreamCommand) -> RunStreamHandle:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("RunsApplicationService is not wired in PR1")

    # -- delegating operations --

    async def cancel(self, command: CancelRunCommand) -> bool:
        """Forward the cancellation to the supervisor and return its result.

        Only ``run_id`` and ``action`` are passed on; ``command.wait`` is not
        consulted here.
        """
        outcome = await self.supervisor.cancel(command.run_id, action=command.action)
        return outcome

    async def get_run(self, query: GetRunQuery) -> RunSnapshot | None:
        """Return a snapshot of the requested run, or ``None``.

        ``None`` is returned when the repository has no such run, or when the
        query names a ``thread_id`` that differs from the run's own.
        """
        run = await self.run_repository.get(query.run_id, user_id=query.user_id)
        if run is None:
            return None
        wrong_thread = query.thread_id is not None and run.thread_id != query.thread_id
        return None if wrong_thread else RunSnapshot.from_run(run)

    async def list_runs(self, query: ListRunsQuery) -> list[RunSnapshot]:
        """Return the repository's listing for the queried thread."""
        snapshots = await self.run_repository.list_by_thread(
            query.thread_id,
            user_id=query.user_id,
            limit=query.limit,
        )
        return snapshots

    async def list_run_messages(self, query: ListRunMessagesQuery) -> list[RunMessageView]:
        """Return the event log's message listing for the queried run."""
        messages = await self.run_event_log.list_messages_by_run(
            query.thread_id,
            query.run_id,
            limit=query.limit,
            before_seq=query.before_seq,
            after_seq=query.after_seq,
        )
        return messages
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunsApplicationService",
|
||||
]
|
||||
@ -0,0 +1,33 @@
|
||||
"""Run runtime domain model."""
|
||||
|
||||
from .errors import InvalidRunTransition, RunDomainError
|
||||
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
|
||||
from .identifiers import AssistantId, RunId, ThreadId, UserId
|
||||
from .model import Run
|
||||
from .policies import CancelPolicy, MultitaskDecision, MultitaskPolicy
|
||||
from .value_objects import CancelAction, DisconnectMode, EventSeq, MultitaskStrategy, RunScope, RunStatus
|
||||
|
||||
__all__ = [
|
||||
"AssistantId",
|
||||
"CancelAction",
|
||||
"CancelPolicy",
|
||||
"DisconnectMode",
|
||||
"EventSeq",
|
||||
"InvalidRunTransition",
|
||||
"MultitaskDecision",
|
||||
"MultitaskPolicy",
|
||||
"MultitaskStrategy",
|
||||
"Run",
|
||||
"RunCancelled",
|
||||
"RunCompleted",
|
||||
"RunCreated",
|
||||
"RunDomainError",
|
||||
"RunEvent",
|
||||
"RunFailed",
|
||||
"RunId",
|
||||
"RunScope",
|
||||
"RunStarted",
|
||||
"RunStatus",
|
||||
"ThreadId",
|
||||
"UserId",
|
||||
]
|
||||
@ -0,0 +1,24 @@
|
||||
"""Domain-level errors for run lifecycle operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .value_objects import RunStatus
|
||||
|
||||
|
||||
class RunDomainError(Exception):
    """Base class for run runtime domain errors."""


class InvalidRunTransition(RunDomainError):
    """Raised when a run status transition violates lifecycle rules.

    Attributes:
        current: Status the run was in when the transition was attempted.
        target: Status that was requested.
    """

    def __init__(self, current: RunStatus, target: RunStatus) -> None:
        super().__init__(f"Cannot transition run from {current.value!r} to {target.value!r}")
        self.current = current
        self.target = target

    def __reduce__(self):
        # ``BaseException`` rebuilds instances from ``self.args``, which holds
        # only the formatted message here. The constructor needs both
        # statuses, so copying or unpickling would otherwise raise TypeError.
        return (type(self), (self.current, self.target))
|
||||
|
||||
|
||||
__all__ = [
|
||||
"InvalidRunTransition",
|
||||
"RunDomainError",
|
||||
]
|
||||
@ -0,0 +1,64 @@
|
||||
"""Domain events emitted by the run aggregate."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from deerflow.utils.time import now_iso
|
||||
|
||||
from .identifiers import AssistantId, RunId, ThreadId
|
||||
from .value_objects import CancelAction, RunStatus
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunCreated:
    """Domain event recording that a run was created."""

    run_id: RunId
    thread_id: ThreadId
    # Defaults to the value returned by ``now_iso`` at construction time.
    occurred_at: str = field(default_factory=now_iso)
    assistant_id: AssistantId | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunStarted:
    """Domain event recording that a run moved to the running status."""

    run_id: RunId
    thread_id: ThreadId
    # Defaults to the value returned by ``now_iso`` at construction time.
    occurred_at: str = field(default_factory=now_iso)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunCompleted:
    """Domain event recording that a run finished successfully."""

    run_id: RunId
    thread_id: ThreadId
    # Defaults to the value returned by ``now_iso`` at construction time.
    occurred_at: str = field(default_factory=now_iso)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunFailed:
    """Domain event recording that a run ended in a failure status.

    ``status`` carries the concrete failure status; the run aggregate emits
    this event for both ``RunStatus.error`` and ``RunStatus.timeout``.
    """

    run_id: RunId
    thread_id: ThreadId
    status: RunStatus
    # Defaults to the value returned by ``now_iso`` at construction time.
    occurred_at: str = field(default_factory=now_iso)
    error: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunCancelled:
    """Domain event recording that a run was cancelled."""

    run_id: RunId
    thread_id: ThreadId
    # Defaults to the value returned by ``now_iso`` at construction time.
    occurred_at: str = field(default_factory=now_iso)
    # Cancellation action that was applied; defaults to ``CancelAction.interrupt``.
    action: CancelAction = CancelAction.interrupt
|
||||
|
||||
|
||||
# Union of the five run domain event classes defined in this module.
RunEvent = RunCreated | RunStarted | RunCompleted | RunFailed | RunCancelled
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunCancelled",
|
||||
"RunCompleted",
|
||||
"RunCreated",
|
||||
"RunEvent",
|
||||
"RunFailed",
|
||||
"RunStarted",
|
||||
]
|
||||
@ -0,0 +1,27 @@
|
||||
"""Lightweight identifiers for the run runtime domain."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NewType
|
||||
|
||||
RunId = NewType("RunId", str)
|
||||
ThreadId = NewType("ThreadId", str)
|
||||
AssistantId = NewType("AssistantId", str)
|
||||
UserId = NewType("UserId", str)
|
||||
|
||||
|
||||
def require_non_empty(value: str, *, field_name: str) -> str:
    """Strip ``value`` and return it, refusing blank identifiers.

    Args:
        value: Raw identifier text.
        field_name: Name used in the error message when ``value`` is blank.

    Returns:
        ``value`` with leading and trailing whitespace removed.

    Raises:
        ValueError: If nothing remains after stripping.
    """
    stripped = value.strip()
    if stripped:
        return stripped
    raise ValueError(f"{field_name} must not be empty")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AssistantId",
|
||||
"RunId",
|
||||
"ThreadId",
|
||||
"UserId",
|
||||
"require_non_empty",
|
||||
]
|
||||
193
backend/packages/harness/deerflow/runtime/runs/domain/model.py
Normal file
193
backend/packages/harness/deerflow/runtime/runs/domain/model.py
Normal file
@ -0,0 +1,193 @@
|
||||
"""Run aggregate root and lifecycle invariants."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from deerflow.utils.time import now_iso
|
||||
|
||||
from .errors import InvalidRunTransition
|
||||
from .events import RunCancelled, RunCompleted, RunCreated, RunEvent, RunFailed, RunStarted
|
||||
from .identifiers import AssistantId, RunId, ThreadId, require_non_empty
|
||||
from .value_objects import CancelAction, MultitaskStrategy, RunScope, RunStatus
|
||||
|
||||
# Keep lifecycle transitions explicit so later application code cannot invent
# ad hoc status moves outside the aggregate.
#
# Each key maps to the statuses reachable from it. An empty set marks a
# terminal status, which is what ``Run.is_terminal`` tests for.
_ALLOWED_TRANSITIONS: dict[RunStatus, frozenset[RunStatus]] = {
    RunStatus.pending: frozenset(
        (
            RunStatus.running,
            RunStatus.error,
            RunStatus.timeout,
            RunStatus.interrupted,
        )
    ),
    RunStatus.running: frozenset(
        (
            RunStatus.success,
            RunStatus.error,
            RunStatus.timeout,
            RunStatus.interrupted,
        )
    ),
    RunStatus.success: frozenset(),
    RunStatus.error: frozenset(),
    RunStatus.timeout: frozenset(),
    RunStatus.interrupted: frozenset(),
}
|
||||
|
||||
|
||||
@dataclass
class Run:
    """Run aggregate root.

    The aggregate owns lifecycle invariants only. Infrastructure concerns such
    as SQL sessions, SSE frames, Redis clients, and FastAPI requests stay out of
    this model.

    Status changes go through the ``mark_*`` methods. Each accepted change is
    validated against ``_ALLOWED_TRANSITIONS`` and queues one domain event,
    which the caller drains with :meth:`pull_events`.
    """

    run_id: RunId
    thread_id: ThreadId
    status: RunStatus
    assistant_id: AssistantId | None = None
    scope: RunScope = RunScope.stateful
    multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject
    metadata: dict[str, Any] = field(default_factory=dict)
    kwargs: dict[str, Any] = field(default_factory=dict)
    created_at: str = field(default_factory=now_iso)
    updated_at: str = field(default_factory=now_iso)
    error: str | None = None
    model_name: str | None = None
    # Events recorded since the last ``pull_events`` call; not a constructor argument.
    _pending_events: list[RunEvent] = field(default_factory=list, init=False, repr=False)

    def __post_init__(self) -> None:
        """Normalize identifiers by stripping whitespace and rejecting blanks.

        Raises:
            ValueError: If ``run_id``, ``thread_id`` or a provided
                ``assistant_id`` is empty after stripping.
        """
        self.run_id = RunId(require_non_empty(str(self.run_id), field_name="run_id"))
        self.thread_id = ThreadId(require_non_empty(str(self.thread_id), field_name="thread_id"))
        if self.assistant_id is not None:
            self.assistant_id = AssistantId(require_non_empty(str(self.assistant_id), field_name="assistant_id"))

    @classmethod
    def create(
        cls,
        *,
        run_id: RunId,
        thread_id: ThreadId,
        assistant_id: AssistantId | None = None,
        scope: RunScope = RunScope.stateful,
        multitask_strategy: MultitaskStrategy = MultitaskStrategy.reject,
        metadata: dict[str, Any] | None = None,
        kwargs: dict[str, Any] | None = None,
        model_name: str | None = None,
        created_at: str | None = None,
    ) -> Run:
        """Create a pending run and record its ``RunCreated`` event.

        ``created_at`` (or the current ``now_iso()`` value when omitted) is
        used for ``created_at``, ``updated_at`` and the event timestamp.

        Raises:
            ValueError: If an identifier is blank (see ``__post_init__``).
        """
        timestamp = created_at or now_iso()
        run = cls(
            run_id=run_id,
            thread_id=thread_id,
            assistant_id=assistant_id,
            status=RunStatus.pending,
            scope=scope,
            multitask_strategy=multitask_strategy,
            # Always copy: ``metadata or {}`` aliased the caller's dict only
            # when it was non-empty, so later caller mutations could leak in.
            metadata=dict(metadata or {}),
            kwargs=dict(kwargs or {}),
            created_at=timestamp,
            updated_at=timestamp,
            model_name=model_name,
        )
        run._record_event(
            RunCreated(
                run_id=run.run_id,
                thread_id=run.thread_id,
                occurred_at=timestamp,
                assistant_id=run.assistant_id,
                metadata=dict(run.metadata),
            )
        )
        return run

    @property
    def is_terminal(self) -> bool:
        """Whether the current status has no allowed outgoing transition."""
        return not _ALLOWED_TRANSITIONS[self.status]

    def pull_events(self) -> tuple[RunEvent, ...]:
        """Return the queued domain events and empty the queue."""
        # Domain events are drained by the application layer after the aggregate
        # has accepted a state change.
        events = tuple(self._pending_events)
        self._pending_events.clear()
        return events

    def mark_started(self, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.running``."""
        self._transition_to(RunStatus.running, at=at)

    def mark_completed(self, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.success``."""
        self._transition_to(RunStatus.success, at=at)

    def mark_failed(self, error: str | None = None, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.error``, storing ``error`` when given."""
        self._transition_to(RunStatus.error, error=error, at=at)

    def mark_timed_out(self, error: str | None = None, *, at: str | None = None) -> None:
        """Move the run to ``RunStatus.timeout``, storing ``error`` when given."""
        self._transition_to(RunStatus.timeout, error=error, at=at)

    def mark_cancelled(self, *, action: CancelAction = CancelAction.interrupt, at: str | None = None) -> None:
        """Move the run to ``RunStatus.interrupted``, recording ``action``."""
        self._transition_to(RunStatus.interrupted, action=action, at=at)

    def _transition_to(
        self,
        target: RunStatus,
        *,
        error: str | None = None,
        action: CancelAction = CancelAction.interrupt,
        at: str | None = None,
    ) -> None:
        """Apply a validated status change and queue the matching event.

        Requesting the status the run already has is a silent no-op: nothing
        is updated and no event is queued.

        Raises:
            InvalidRunTransition: If ``target`` is not reachable from the
                current status.
        """
        if target == self.status:
            return
        if target not in _ALLOWED_TRANSITIONS[self.status]:
            raise InvalidRunTransition(self.status, target)

        timestamp = at or now_iso()
        # Build the event before touching any state. If event construction
        # rejects the target, the aggregate stays unchanged and the error
        # reports the real current status.
        event = self._event_for_transition(target, timestamp, error=error, action=action)
        self.status = target
        self.updated_at = timestamp
        if error is not None:
            self.error = error
        self._record_event(event)

    def _event_for_transition(
        self,
        target: RunStatus,
        occurred_at: str,
        *,
        error: str | None,
        action: CancelAction,
    ) -> RunEvent:
        """Return the domain event that describes a move to ``target``.

        Raises:
            InvalidRunTransition: If ``target`` has no event shape defined.
        """
        # Keep event construction next to the transition rules so a new status
        # cannot be added without an explicit durable event shape.
        if target == RunStatus.running:
            return RunStarted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
        if target == RunStatus.success:
            return RunCompleted(run_id=self.run_id, thread_id=self.thread_id, occurred_at=occurred_at)
        if target in (RunStatus.error, RunStatus.timeout):
            return RunFailed(
                run_id=self.run_id,
                thread_id=self.thread_id,
                status=target,
                occurred_at=occurred_at,
                error=error,
            )
        if target == RunStatus.interrupted:
            return RunCancelled(
                run_id=self.run_id,
                thread_id=self.thread_id,
                occurred_at=occurred_at,
                action=action,
            )
        raise InvalidRunTransition(self.status, target)

    def _record_event(self, event: RunEvent) -> None:
        """Append ``event`` to the pending queue."""
        self._pending_events.append(event)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Run",
|
||||
"RunStatus",
|
||||
]
|
||||
@ -0,0 +1,50 @@
|
||||
"""Domain policies for run concurrency and cancellation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
from .model import Run
|
||||
from .value_objects import CancelAction, MultitaskStrategy, RunStatus
|
||||
|
||||
|
||||
class MultitaskDecision(StrEnum):
|
||||
"""Application-level decision produced by a multitask policy."""
|
||||
|
||||
allow = "allow"
|
||||
reject = "reject"
|
||||
cancel_existing = "cancel_existing"
|
||||
enqueue = "enqueue"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MultitaskPolicy:
    """Turns a multitask strategy into a decision for a newly requested run."""

    strategy: MultitaskStrategy = MultitaskStrategy.reject

    def decide(self, active_runs: Sequence[Run]) -> MultitaskDecision:
        """Decide how to handle a new run given ``active_runs``.

        Only runs whose status is pending or running count as in flight. With
        none in flight the answer is always ``allow``. Otherwise the strategy
        decides: ``reject`` rejects, ``interrupt`` and ``rollback`` both
        cancel the existing runs, and any other strategy enqueues.
        """
        in_flight_statuses = (RunStatus.pending, RunStatus.running)
        if not any(run.status in in_flight_statuses for run in active_runs):
            return MultitaskDecision.allow

        if self.strategy == MultitaskStrategy.reject:
            return MultitaskDecision.reject

        cancelling_strategies = (MultitaskStrategy.interrupt, MultitaskStrategy.rollback)
        if self.strategy in cancelling_strategies:
            return MultitaskDecision.cancel_existing

        return MultitaskDecision.enqueue
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class CancelPolicy:
    """Immutable wrapper around the cancellation action to apply."""

    action: CancelAction = CancelAction.interrupt

    @property
    def rolls_back_checkpoint(self) -> bool:
        """``True`` exactly when the action is ``CancelAction.rollback``."""
        is_rollback = self.action == CancelAction.rollback
        return is_rollback
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CancelPolicy",
|
||||
"MultitaskDecision",
|
||||
"MultitaskPolicy",
|
||||
]
|
||||
@ -0,0 +1,88 @@
|
||||
"""Domain value objects for run lifecycle semantics."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class RunStatus(StrEnum):
|
||||
"""Lifecycle status of a single run."""
|
||||
|
||||
pending = "pending"
|
||||
running = "running"
|
||||
success = "success"
|
||||
error = "error"
|
||||
timeout = "timeout"
|
||||
interrupted = "interrupted"
|
||||
|
||||
|
||||
class DisconnectMode(StrEnum):
|
||||
"""Behaviour when the SSE consumer disconnects."""
|
||||
|
||||
cancel = "cancel"
|
||||
continue_ = "continue"
|
||||
|
||||
|
||||
class RunScope(StrEnum):
|
||||
"""Conversation scope for a run."""
|
||||
|
||||
stateful = "stateful"
|
||||
stateless = "stateless"
|
||||
temporary_thread = "temporary_thread"
|
||||
|
||||
|
||||
class MultitaskStrategy(StrEnum):
|
||||
"""Concurrency strategy for a new run on a thread."""
|
||||
|
||||
reject = "reject"
|
||||
interrupt = "interrupt"
|
||||
rollback = "rollback"
|
||||
enqueue = "enqueue"
|
||||
|
||||
|
||||
class CancelAction(StrEnum):
|
||||
"""Cancellation action requested by an API or supervisor."""
|
||||
|
||||
interrupt = "interrupt"
|
||||
rollback = "rollback"
|
||||
|
||||
|
||||
# Statuses that ``is_terminal_status`` reports as terminal.
TERMINAL_RUN_STATUSES: frozenset[RunStatus] = frozenset(
    (
        RunStatus.success,
        RunStatus.error,
        RunStatus.timeout,
        RunStatus.interrupted,
    )
)
|
||||
|
||||
|
||||
def is_terminal_status(status: RunStatus) -> bool:
    """Return ``True`` when ``status`` is a member of ``TERMINAL_RUN_STATUSES``."""
    terminal = status in TERMINAL_RUN_STATUSES
    return terminal
|
||||
|
||||
|
||||
@dataclass(frozen=True, order=True)
class EventSeq:
    """Thread-local event sequence number.

    Instances are immutable and ordered by ``value``; negative values are
    rejected at construction.
    """

    value: int

    def __post_init__(self) -> None:
        """Validate the sequence number.

        Raises:
            ValueError: If ``value`` is negative.
        """
        is_negative = self.value < 0
        if is_negative:
            raise ValueError("EventSeq must be non-negative")

    def next(self) -> EventSeq:
        """Return a new ``EventSeq`` one greater than this one."""
        successor = self.value + 1
        return EventSeq(successor)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CancelAction",
|
||||
"DisconnectMode",
|
||||
"EventSeq",
|
||||
"MultitaskStrategy",
|
||||
"RunScope",
|
||||
"RunStatus",
|
||||
"TERMINAL_RUN_STATUSES",
|
||||
"is_terminal_status",
|
||||
]
|
||||
@ -0,0 +1,12 @@
|
||||
"""Execution contracts for run lifecycle orchestration."""
|
||||
|
||||
from .executor import RunExecutor
|
||||
from .scheduler import RunExecutionHandle, RunExecutionScheduler
|
||||
from .supervisor import RunSupervisor
|
||||
|
||||
__all__ = [
|
||||
"RunExecutionHandle",
|
||||
"RunExecutionScheduler",
|
||||
"RunExecutor",
|
||||
"RunSupervisor",
|
||||
]
|
||||
@ -0,0 +1,18 @@
|
||||
"""Run executor contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..domain import Run
|
||||
|
||||
|
||||
class RunExecutor(Protocol):
    """Executes one run against the underlying agent or graph runtime."""

    async def execute(self, run: Run) -> None:
        """Execute ``run``; implementations supply the behavior."""
        ...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunExecutor",
|
||||
]
|
||||
@ -0,0 +1,25 @@
|
||||
"""Run execution scheduler contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from ..domain import RunId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunExecutionHandle:
    """Immutable handle returned by ``RunExecutionScheduler.start``."""

    # Identifier of the run the handle refers to.
    run_id: RunId
|
||||
|
||||
|
||||
class RunExecutionScheduler(Protocol):
    """Starts background execution for an accepted run."""

    async def start(self, run_id: RunId) -> RunExecutionHandle:
        """Start execution for ``run_id`` and return a handle for it."""
        ...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunExecutionHandle",
|
||||
"RunExecutionScheduler",
|
||||
]
|
||||
@ -0,0 +1,18 @@
|
||||
"""Run execution supervision contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from ..domain import CancelAction, RunId
|
||||
|
||||
|
||||
class RunSupervisor(Protocol):
    """Controls lifecycle operations for already scheduled runs."""

    async def cancel(self, run_id: RunId, *, action: CancelAction = CancelAction.interrupt) -> bool:
        """Cancel ``run_id`` using ``action`` and return a boolean result.

        The meaning of the boolean is defined by implementations.
        """
        ...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunSupervisor",
|
||||
]
|
||||
@ -0,0 +1,9 @@
|
||||
"""Repository contracts for the run runtime application layer."""
|
||||
|
||||
from .run_event_log import RunEventLog
|
||||
from .run_repository import RunRepository
|
||||
|
||||
__all__ = [
|
||||
"RunEventLog",
|
||||
"RunRepository",
|
||||
]
|
||||
@ -0,0 +1,39 @@
|
||||
"""Durable run event log contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from ..domain import RunEvent, RunId, ThreadId
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..application.dto import RunMessageView, StoredRunEvent
|
||||
|
||||
|
||||
class RunEventLog(Protocol):
    """Persistence boundary for run messages and execution trace events."""

    async def append(self, events: list[RunEvent]) -> list[StoredRunEvent]:
        """Append domain events and return their stored representations."""
        ...

    async def list_messages_by_run(
        self,
        thread_id: ThreadId,
        run_id: RunId,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[RunMessageView]:
        """List message views for one run.

        ``limit`` caps the result size; ``before_seq`` and ``after_seq`` are
        optional sequence-number bounds.
        """
        ...

    async def list_events_by_run(
        self,
        thread_id: ThreadId,
        run_id: RunId,
        *,
        limit: int = 500,
    ) -> list[StoredRunEvent]:
        """List stored events for one run, capped at ``limit``."""
        ...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunEventLog",
|
||||
]
|
||||
@ -0,0 +1,33 @@
|
||||
"""Run state repository contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
from ..domain import Run, RunId, ThreadId, UserId
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..application.dto import RunSnapshot
|
||||
|
||||
|
||||
class RunRepository(Protocol):
    """Persistence boundary for run state snapshots."""

    async def save(self, run: Run) -> None:
        """Persist ``run``."""
        ...

    async def get(self, run_id: RunId, *, user_id: UserId | None = None) -> Run | None:
        """Return the run with ``run_id``, or ``None`` when there is none."""
        ...

    async def list_by_thread(
        self,
        thread_id: ThreadId,
        *,
        user_id: UserId | None = None,
        limit: int = 100,
    ) -> list[RunSnapshot]:
        """List snapshots of the runs on ``thread_id``, capped at ``limit``.

        Unlike ``get``, this returns ``RunSnapshot`` views rather than ``Run``
        aggregates.
        """
        ...

    async def delete(self, run_id: RunId) -> bool:
        """Delete the run with ``run_id`` and return a boolean result."""
        ...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunRepository",
|
||||
]
|
||||
@ -1,21 +1,10 @@
|
||||
"""Run status and disconnect mode enums."""
|
||||
"""Compatibility exports for run status and disconnect mode enums."""
|
||||
|
||||
from enum import StrEnum
|
||||
# Existing callers import these enums from ``runs.schemas``. Re-export the
|
||||
# domain definitions until all imports move to ``runs.domain``.
|
||||
from .domain import DisconnectMode, RunStatus
|
||||
|
||||
|
||||
class RunStatus(StrEnum):
|
||||
"""Lifecycle status of a single run."""
|
||||
|
||||
pending = "pending"
|
||||
running = "running"
|
||||
success = "success"
|
||||
error = "error"
|
||||
timeout = "timeout"
|
||||
interrupted = "interrupted"
|
||||
|
||||
|
||||
class DisconnectMode(StrEnum):
|
||||
"""Behaviour when the SSE consumer disconnects."""
|
||||
|
||||
cancel = "cancel"
|
||||
continue_ = "continue"
|
||||
__all__ = [
|
||||
"DisconnectMode",
|
||||
"RunStatus",
|
||||
]
|
||||
|
||||
@ -0,0 +1,8 @@
|
||||
"""Realtime stream contracts for run application use cases."""
|
||||
|
||||
from .run_stream_broker import RunStreamBroker, RunStreamEvent
|
||||
|
||||
__all__ = [
|
||||
"RunStreamBroker",
|
||||
"RunStreamEvent",
|
||||
]
|
||||
@ -0,0 +1,40 @@
|
||||
"""Realtime run stream broker contract."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Protocol
|
||||
|
||||
from ..domain import RunId
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RunStreamEvent:
    """Immutable item yielded by ``RunStreamBroker.subscribe``."""

    # Event identifier.
    id: str
    # Event name.
    event: str
    # Payload; the type is left open (``Any``).
    data: Any
|
||||
|
||||
|
||||
class RunStreamBroker(Protocol):
    """Realtime publish/subscribe boundary for run streams."""

    async def publish(self, run_id: RunId, event: str, data: Any) -> None:
        """Publish one ``event`` with ``data`` on the stream of ``run_id``."""
        ...

    async def publish_terminal(self, run_id: RunId, *, event: str = "end", data: Any = None) -> None:
        """Publish the terminal event for ``run_id``; it is named "end" by default."""
        ...

    def subscribe(
        self,
        run_id: RunId,
        *,
        last_event_id: str | None = None,
        heartbeat_interval: float = 15.0,
    ) -> AsyncIterator[RunStreamEvent]:
        """Return an async iterator of events for ``run_id``.

        This is a plain method, not a coroutine: the caller iterates the
        returned object with ``async for``.
        """
        ...

    async def cleanup(self, run_id: RunId, *, delay: float = 0) -> None:
        """Release the stream of ``run_id``; ``delay`` defaults to 0."""
        ...
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RunStreamBroker",
|
||||
"RunStreamEvent",
|
||||
]
|
||||
109
backend/tests/test_run_domain.py
Normal file
109
backend/tests/test_run_domain.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Tests for the DDD run domain skeleton."""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.runs import DisconnectMode, RunStatus
|
||||
from deerflow.runtime.runs.domain import (
|
||||
AssistantId,
|
||||
CancelAction,
|
||||
EventSeq,
|
||||
InvalidRunTransition,
|
||||
MultitaskStrategy,
|
||||
Run,
|
||||
RunCancelled,
|
||||
RunCompleted,
|
||||
RunCreated,
|
||||
RunFailed,
|
||||
RunId,
|
||||
RunScope,
|
||||
RunStarted,
|
||||
ThreadId,
|
||||
)
|
||||
from deerflow.runtime.runs.schemas import DisconnectMode as CompatDisconnectMode
|
||||
from deerflow.runtime.runs.schemas import RunStatus as CompatRunStatus
|
||||
|
||||
|
||||
def test_compat_schema_exports_use_domain_enums() -> None:
    """Names imported from ``runs.schemas`` are the same objects the package exports."""
    compat_pairs = (
        (CompatRunStatus, RunStatus),
        (CompatDisconnectMode, DisconnectMode),
    )
    for compat, canonical in compat_pairs:
        assert compat is canonical
|
||||
|
||||
|
||||
def test_create_run_records_pending_state_and_created_event() -> None:
    """``Run.create`` yields a pending run and queues exactly one ``RunCreated``."""
    stamp = "2026-01-01T00:00:00+00:00"
    run = Run.create(
        run_id=RunId("run-1"),
        thread_id=ThreadId("thread-1"),
        assistant_id=AssistantId("lead_agent"),
        scope=RunScope.stateful,
        multitask_strategy=MultitaskStrategy.reject,
        metadata={"source": "test"},
        kwargs={"input": {"messages": []}},
        created_at=stamp,
    )

    assert run.status == RunStatus.pending
    assert run.run_id == "run-1"
    assert run.thread_id == "thread-1"
    assert run.assistant_id == "lead_agent"
    assert run.created_at == stamp
    assert run.updated_at == stamp

    drained = run.pull_events()
    assert len(drained) == 1
    created_event = drained[0]
    assert isinstance(created_event, RunCreated)
    assert created_event.metadata == {"source": "test"}
    # A second pull must find the queue empty.
    assert run.pull_events() == ()
|
||||
|
||||
|
||||
def test_run_allows_pending_running_success_transition() -> None:
    """pending -> running -> success is accepted and emits the matching events."""
    finished_at = "2026-01-01T00:00:02+00:00"
    run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
    run.pull_events()  # discard the creation event

    run.mark_started(at="2026-01-01T00:00:01+00:00")
    run.mark_completed(at=finished_at)

    assert run.status == RunStatus.success
    assert run.updated_at == finished_at
    emitted_types = [type(event) for event in run.pull_events()]
    assert emitted_types == [RunStarted, RunCompleted]
|
||||
|
||||
|
||||
def test_run_records_failed_and_cancelled_domain_events() -> None:
    """Failure and cancellation each queue an event carrying their details."""
    # Failure path: pending -> running -> error.
    failed = Run.create(run_id=RunId("run-failed"), thread_id=ThreadId("thread-1"))
    failed.pull_events()
    failed.mark_started()
    failed.mark_failed("boom", at="2026-01-01T00:00:03+00:00")
    last_failed_event = failed.pull_events()[-1]

    assert failed.status == RunStatus.error
    assert isinstance(last_failed_event, RunFailed)
    assert last_failed_event.status == RunStatus.error
    assert last_failed_event.error == "boom"

    # Cancellation path: pending -> interrupted with a rollback action.
    cancelled = Run.create(run_id=RunId("run-cancelled"), thread_id=ThreadId("thread-1"))
    cancelled.pull_events()
    cancelled.mark_cancelled(action=CancelAction.rollback)
    last_cancelled_event = cancelled.pull_events()[-1]

    assert cancelled.status == RunStatus.interrupted
    assert isinstance(last_cancelled_event, RunCancelled)
    assert last_cancelled_event.action == CancelAction.rollback
|
||||
|
||||
|
||||
def test_terminal_run_cannot_transition_again() -> None:
    """A run that reached success rejects a later move to error."""
    run = Run.create(run_id=RunId("run-1"), thread_id=ThreadId("thread-1"))
    run.mark_started()
    run.mark_completed()

    with pytest.raises(InvalidRunTransition) as exc:
        run.mark_failed("too late")

    raised = exc.value
    assert raised.current == RunStatus.success
    assert raised.target == RunStatus.error
|
||||
|
||||
|
||||
def test_domain_value_objects_validate_minimal_invariants() -> None:
    """``EventSeq`` and run identifiers reject invalid values."""
    first = EventSeq(1)
    assert first.next() == EventSeq(2)

    with pytest.raises(ValueError, match="EventSeq"):
        EventSeq(-1)

    # A whitespace-only run id counts as empty.
    with pytest.raises(ValueError, match="run_id"):
        Run.create(run_id=RunId(" "), thread_id=ThreadId("thread-1"))
|
||||
Loading…
x
Reference in New Issue
Block a user