mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-02 06:48:21 +00:00
feat: asyncio-native Actor framework with supervision, middleware, and pluggable mailbox
Lightweight actor library built on asyncio primitives (~800 lines):
- Actor base class with lifecycle hooks (on_started/on_stopped/on_restart)
- ActorRef with tell (fire-and-forget) and ask (request-response)
- Supervision: OneForOne/AllForOne strategies with restart limits
- Middleware pipeline for cross-cutting concerns
- Pluggable Mailbox interface (MemoryMailbox default, RedisMailbox optional)
- ReplyRegistry + ReplyChannel: ask() works across any mailbox backend
- System-level thread pool for blocking I/O (run_in_executor)
- Dead letter handling, poison message quarantine, parallel shutdown
- 22 tests + benchmark suite
This commit is contained in:
parent
9e3d484858
commit
3e17417122
2
.gitignore
vendored
2
.gitignore
vendored
@ -2,6 +2,8 @@
|
||||
docker/.cache/
|
||||
# oh-my-claudecode state
|
||||
.omc/
|
||||
# Collaborator plugin state
|
||||
.collaborator/
|
||||
# OS generated files
|
||||
.DS_Store
|
||||
*.local
|
||||
|
||||
40
backend/packages/harness/deerflow/actor/__init__.py
Normal file
40
backend/packages/harness/deerflow/actor/__init__.py
Normal file
@ -0,0 +1,40 @@
|
||||
"""Async Actor framework — lightweight, asyncio-native, supervision-ready.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.actor import Actor, ActorSystem
|
||||
|
||||
class Greeter(Actor):
|
||||
async def on_receive(self, message):
|
||||
return f"Hello, {message}!"
|
||||
|
||||
async def main():
|
||||
system = ActorSystem("app")
|
||||
ref = await system.spawn(Greeter, "greeter")
|
||||
reply = await ref.ask("World", timeout=5.0)
|
||||
print(reply) # Hello, World!
|
||||
await system.shutdown()
|
||||
"""
|
||||
|
||||
from .actor import Actor, ActorContext
|
||||
from .mailbox import Mailbox, MemoryMailbox
|
||||
from .middleware import Middleware
|
||||
from .ref import ActorRef, ReplyChannel
|
||||
from .supervision import AllForOneStrategy, Directive, OneForOneStrategy, SupervisorStrategy
|
||||
from .system import ActorSystem, DeadLetter
|
||||
|
||||
# Public API surface — keep alphabetized and in sync with the imports above.
__all__ = [
    "Actor",
    "ActorContext",
    "ActorRef",
    "ActorSystem",
    "AllForOneStrategy",
    "DeadLetter",
    "Directive",
    "Mailbox",
    "MemoryMailbox",
    "Middleware",
    "OneForOneStrategy",
    "ReplyChannel",
    "SupervisorStrategy",
]
|
||||
109
backend/packages/harness/deerflow/actor/actor.py
Normal file
109
backend/packages/harness/deerflow/actor/actor.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""Actor base class and per-actor context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from .supervision import OneForOneStrategy, SupervisorStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ref import ActorRef
|
||||
|
||||
# Message type variable — use Actor[MyMsg] for typed actors.
M = TypeVar("M")
# NOTE(review): R is not referenced anywhere in this module — presumably
# reserved for future typed ask() replies; confirm before removing.
R = TypeVar("R")
|
||||
|
||||
|
||||
class ActorContext:
    """Per-actor runtime context, injected before ``on_started``.

    Exposes the actor's identity, its parent and children, the owning
    system, and the ability to spawn supervised child actors.
    """

    __slots__ = ("_cell",)

    def __init__(self, cell: Any) -> None:
        self._cell = cell

    @property
    def self_ref(self) -> ActorRef:
        """Reference to the actor that owns this context."""
        return self._cell.ref

    @property
    def parent(self) -> ActorRef | None:
        """Reference to the supervising parent, or ``None`` for root actors."""
        parent_cell = self._cell.parent
        if parent_cell is None:
            return None
        return parent_cell.ref

    @property
    def children(self) -> dict[str, ActorRef]:
        """Snapshot mapping of child name → ActorRef."""
        return {child_name: child.ref for child_name, child in self._cell.children.items()}

    @property
    def system(self) -> Any:
        """The owning ActorSystem."""
        return self._cell.system

    async def spawn(
        self,
        actor_cls: type[Actor],
        name: str,
        *,
        mailbox_size: int = 256,
        middlewares: list | None = None,
    ) -> ActorRef:
        """Spawn a child actor supervised by this actor."""
        return await self._cell.spawn_child(
            actor_cls,
            name,
            mailbox_size=mailbox_size,
            middlewares=middlewares,
        )

    async def run_in_executor(self, fn: Callable[..., Any], *args: Any) -> Any:
        """Run a blocking function in the system's shared thread pool.

        Usage::

            result = await self.context.run_in_executor(requests.get, url)
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._cell.system._executor, fn, *args)
|
||||
|
||||
|
||||
class Actor(Generic[M]):
    """Base class for all actors.

    Type parameter ``M`` constrains the message type::

        class Greeter(Actor[str]):
            async def on_receive(self, message: str) -> str:
                return f"Hello, {message}!"

        class Calculator(Actor[int | tuple[str, int, int]]):
            async def on_receive(self, message: int | tuple[str, int, int]) -> int:
                ...

    Unparameterized ``Actor`` accepts ``Any`` (backward-compatible).
    """

    # Injected by the runtime before ``on_started`` runs.
    context: ActorContext

    async def on_receive(self, message: M) -> Any:
        """Handle an incoming message.

        The return value is sent back as the reply for ``ask`` calls.
        For ``tell`` calls, the return value is discarded.
        The default implementation does nothing (returns None).
        """

    async def on_started(self) -> None:
        """Lifecycle hook: called after creation, before receiving messages."""

    async def on_stopped(self) -> None:
        """Lifecycle hook: called on graceful shutdown. Release resources here."""

    async def on_restart(self, error: Exception) -> None:
        """Lifecycle hook: called on the *new* instance before resuming after a crash.

        Args:
            error: The exception that crashed the previous instance.
        """

    def supervisor_strategy(self) -> SupervisorStrategy:
        """Override to customize how this actor supervises its children.

        Default: OneForOne, up to 3 restarts per 60 seconds, always restart.
        """
        return OneForOneStrategy()
|
||||
94
backend/packages/harness/deerflow/actor/mailbox.py
Normal file
94
backend/packages/harness/deerflow/actor/mailbox.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Pluggable mailbox abstraction — Akka-inspired enqueue/dequeue interface.
|
||||
|
||||
Built-in implementations:
|
||||
- ``MemoryMailbox``: asyncio.Queue backed (default)
|
||||
- Extend ``Mailbox`` for Redis, RabbitMQ, Kafka, etc.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Mailbox(abc.ABC):
    """Abstract mailbox — the message queue for an actor.

    Implementations must be async-safe for single-consumer usage.
    Multiple producers may call ``put`` concurrently.

    NOTE(review): ``empty`` is a plain method while ``full`` is a property;
    subclasses must mirror that asymmetry exactly.
    """

    @abc.abstractmethod
    async def put(self, msg: Any) -> bool:
        """Enqueue a message. Returns True if accepted, False if dropped."""

    @abc.abstractmethod
    def put_nowait(self, msg: Any) -> bool:
        """Non-blocking enqueue. Returns True if accepted, False if dropped."""

    @abc.abstractmethod
    async def get(self) -> Any:
        """Dequeue the next message. Blocks until available."""

    @abc.abstractmethod
    def get_nowait(self) -> Any:
        """Non-blocking dequeue. Raises ``Empty`` if no message."""

    @abc.abstractmethod
    def empty(self) -> bool:
        """Return True if no messages are queued."""

    @property
    @abc.abstractmethod
    def full(self) -> bool:
        """Return True if mailbox is at capacity."""

    async def close(self) -> None:
        """Release resources. Default is no-op."""
|
||||
|
||||
|
||||
class Empty(Exception):
    """Raised by ``get_nowait`` when the mailbox has no message.

    Analogous to ``queue.Empty`` — lets callers poll without blocking.
    """
|
||||
|
||||
|
||||
class MemoryMailbox(Mailbox):
    """In-process mailbox backed by ``asyncio.Queue``.

    Args:
        maxsize: Maximum queued messages. ``put`` applies backpressure
            (blocks) when full; ``put_nowait`` drops instead.
    """

    def __init__(self, maxsize: int = 256) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=maxsize)
        self._maxsize = maxsize

    async def put(self, msg: Any) -> bool:
        """Enqueue *msg*, waiting for capacity if the queue is full.

        Always returns True: ``asyncio.Queue.put`` blocks rather than
        raising when full, so a message is never dropped here. (The
        previous ``except asyncio.QueueFull`` branch was unreachable —
        only ``put_nowait`` raises QueueFull.)
        """
        await self._queue.put(msg)
        return True

    def put_nowait(self, msg: Any) -> bool:
        """Enqueue without blocking. Returns False if the queue is full."""
        try:
            self._queue.put_nowait(msg)
        except asyncio.QueueFull:
            return False
        return True

    async def get(self) -> Any:
        """Dequeue the next message, waiting until one is available."""
        return await self._queue.get()

    def get_nowait(self) -> Any:
        """Dequeue without blocking. Raises ``Empty`` if no message."""
        try:
            return self._queue.get_nowait()
        except asyncio.QueueEmpty:
            # Suppress the QueueEmpty context — callers only care about Empty.
            raise Empty("mailbox empty") from None

    def empty(self) -> bool:
        """Return True if no messages are queued."""
        return self._queue.empty()

    @property
    def full(self) -> bool:
        """Return True if the queue is at capacity."""
        return self._queue.full()
|
||||
|
||||
|
||||
# Type alias for mailbox factories.
# NOTE(review): ``type[Mailbox] | Any`` collapses to ``Any`` for type
# checkers; the intended type is probably ``Callable[[], Mailbox]`` —
# confirm and tighten.
MailboxFactory = type[Mailbox] | Any  # Callable[[], Mailbox]
|
||||
150
backend/packages/harness/deerflow/actor/mailbox_redis.py
Normal file
150
backend/packages/harness/deerflow/actor/mailbox_redis.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""Redis-backed mailbox — persistent, survives process restart.
|
||||
|
||||
Requires ``redis[hiredis]`` (``uv add redis[hiredis]``).
|
||||
|
||||
Usage::
|
||||
|
||||
import redis.asyncio as redis
|
||||
from deerflow.actor import ActorSystem
|
||||
from deerflow.actor.mailbox_redis import RedisMailbox
|
||||
|
||||
pool = redis.ConnectionPool.from_url("redis://localhost:6379")
|
||||
|
||||
system = ActorSystem("app")
|
||||
ref = await system.spawn(
|
||||
MyActor, "worker",
|
||||
mailbox=RedisMailbox(pool, "actor:inbox:worker"),
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from .mailbox import Empty, Mailbox
|
||||
from .ref import _Envelope, _Stop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _serialize(msg: _Envelope | _Stop) -> str:
    """Encode an envelope (or stop sentinel) as JSON for Redis storage.

    Raises:
        TypeError: if the payload is not JSON-serializable.
    """
    if isinstance(msg, _Stop):
        return json.dumps({"__type__": "stop"})
    record = {
        "__type__": "envelope",
        "payload": msg.payload,
        "correlation_id": msg.correlation_id,
        "reply_to": msg.reply_to,
    }
    try:
        return json.dumps(record)
    except (TypeError, ValueError) as e:
        raise TypeError(f"Payload is not JSON-serializable: {e}. RedisMailbox requires JSON-compatible messages.") from e
|
||||
|
||||
|
||||
def _deserialize(data: str | bytes) -> _Envelope | _Stop:
    """Decode a JSON string produced by ``_serialize``."""
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    record = json.loads(text)
    if record.get("__type__") == "stop":
        return _Stop()
    return _Envelope(
        payload=record.get("payload"),
        sender=None,  # sender refs are process-local and never serialized
        correlation_id=record.get("correlation_id"),
        reply_to=record.get("reply_to"),
    )
|
||||
|
||||
|
||||
class RedisMailbox(Mailbox):
    """Mailbox backed by a Redis LIST.

    Each actor gets its own Redis key (the ``queue_name``).
    Messages are serialized as JSON, so payloads must be JSON-compatible.

    Args:
        pool: A ``redis.asyncio.ConnectionPool`` instance.
        queue_name: Redis key for this actor's inbox (e.g. ``"actor:inbox:worker"``).
        maxlen: Maximum queue length. 0 = unbounded. When exceeded, ``put`` returns False.
        brpop_timeout: Seconds to block on ``get()`` before retrying. Default 1s.
    """

    # Lua script for atomic bounded push: length check + LPUSH in one step,
    # avoiding the TOCTOU race of a separate LLEN followed by LPUSH.
    _LUA_BOUNDED_PUSH = """
    if tonumber(ARGV[2]) > 0 and redis.call('llen', KEYS[1]) >= tonumber(ARGV[2]) then
        return 0
    end
    redis.call('lpush', KEYS[1], ARGV[1])
    return 1
    """

    def __init__(
        self,
        pool: Any,
        queue_name: str,
        *,
        maxlen: int = 0,
        brpop_timeout: float = 1.0,
    ) -> None:
        self._queue_name = queue_name
        self._maxlen = maxlen
        self._brpop_timeout = brpop_timeout
        self._closed = False
        # Lazy import to avoid a hard dependency on redis.
        try:
            import redis.asyncio as aioredis
        except ImportError:
            raise ImportError("RedisMailbox requires 'redis' package. Install with: uv add redis[hiredis]") from None
        self._redis: aioredis.Redis = aioredis.Redis(connection_pool=pool)
        # BUGFIX: the previous code called ``self._redis.evalsha_or_eval(...)``,
        # which does not exist in redis-py. ``register_script`` provides the
        # intended behavior: it runs EVALSHA with the cached SHA and
        # transparently falls back to EVAL on NOSCRIPT.
        self._bounded_push = self._redis.register_script(self._LUA_BOUNDED_PUSH)

    async def put(self, msg: Any) -> bool:
        """Enqueue a message. Returns False if closed or at capacity."""
        if self._closed:
            return False
        data = _serialize(msg)
        if self._maxlen > 0:
            # Atomic check+push via the registered Lua script.
            accepted = await self._bounded_push(keys=[self._queue_name], args=[data, self._maxlen])
            return bool(accepted)
        await self._redis.lpush(self._queue_name, data)
        return True

    def put_nowait(self, msg: Any) -> bool:
        """Redis cannot do synchronous non-blocking enqueue reliably.

        Returns False so the caller uses dead-letter or task.cancel() fallback.
        Use ``put()`` (async) for reliable delivery.
        """
        return False

    async def get(self) -> Any:
        """Blocking dequeue via BRPOP. Retries until a message arrives or the mailbox closes."""
        while not self._closed:
            result = await self._redis.brpop(self._queue_name, timeout=self._brpop_timeout)
            if result is not None:
                # BRPOP returns (key, value); only the value matters here.
                _, data = result
                return _deserialize(data)
        raise Empty("mailbox closed")

    def get_nowait(self) -> Any:
        """Unsupported: Redis offers no synchronous non-blocking read here."""
        raise Empty("Redis mailbox does not support synchronous get_nowait")

    def empty(self) -> bool:
        # Cannot query Redis synchronously. Return True so drain loops
        # terminate immediately and rely on get_nowait raising Empty.
        return True

    @property
    def full(self) -> bool:
        # Cannot query Redis synchronously. Backpressure is enforced
        # atomically inside put() via the Lua script.
        return False

    async def close(self) -> None:
        """Mark closed and release the Redis client."""
        self._closed = True
        await self._redis.aclose()
|
||||
79
backend/packages/harness/deerflow/actor/middleware.py
Normal file
79
backend/packages/harness/deerflow/actor/middleware.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""Middleware pipeline — cross-cutting concerns for actors.
|
||||
|
||||
Inspired by Proto.Actor's sender/receiver middleware model.
|
||||
Middleware intercepts messages before/after the actor processes them.
|
||||
|
||||
Usage::
|
||||
|
||||
class LoggingMiddleware(Middleware):
|
||||
async def on_receive(self, ctx, message, next_fn):
|
||||
logger.info("Received: %s", message)
|
||||
result = await next_fn(ctx, message)
|
||||
logger.info("Replied: %s", result)
|
||||
return result
|
||||
|
||||
system = ActorSystem("app")
|
||||
ref = await system.spawn(MyActor, "a", middlewares=[LoggingMiddleware()])
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ActorMailboxContext:
    """Context passed to middleware on each message.

    Attributes:
        actor_ref: Ref of the receiving actor.
        sender: Ref of the sending actor, if it identified itself.
        message_type: Delivery mode — ``"tell"`` or ``"ask"``.
    """

    __slots__ = ("actor_ref", "sender", "message_type")

    def __init__(self, actor_ref: Any, sender: Any, message_type: str) -> None:
        self.actor_ref = actor_ref
        self.sender = sender
        self.message_type = message_type  # "tell" or "ask"
|
||||
|
||||
|
||||
# The inner handler signature: (ctx, message) -> result.
# The composed chain returned by ``build_middleware_chain`` is itself a NextFn.
NextFn = Callable[[ActorMailboxContext, Any], Awaitable[Any]]
|
||||
|
||||
|
||||
class Middleware:
    """Base class for actor middleware.

    Override ``on_receive`` to intercept inbound messages.
    Must call ``await next_fn(ctx, message)`` to continue the chain;
    skipping the call short-circuits everything beneath this layer.
    """

    async def on_receive(self, ctx: ActorMailboxContext, message: Any, next_fn: NextFn) -> Any:
        """Intercept a message. Call next_fn to continue the chain.

        The default implementation is a transparent pass-through.
        """
        return await next_fn(ctx, message)

    async def on_started(self, actor_ref: Any) -> None:
        """Lifecycle hook: called when the actor starts."""

    async def on_stopped(self, actor_ref: Any) -> None:
        """Lifecycle hook: called when the actor stops."""

    async def on_restart(self, actor_ref: Any, error: Exception) -> None:
        """Lifecycle hook: called when the actor restarts after a crash.

        Override to reset per-actor-instance state (caches, counters, etc.)
        that should not bleed across restarts.
        """
|
||||
|
||||
|
||||
def build_middleware_chain(middlewares: list[Middleware], handler: NextFn) -> NextFn:
    """Compose *middlewares* around *handler*, outermost-first.

    ``[A, B, C]`` → ``A(B(C(handler)))``: the first middleware in the list
    sees the message first and the reply last.
    """

    def wrap(mw: Middleware, inner: NextFn) -> NextFn:
        # Factory closure: binds mw/inner per layer without relying on
        # default-argument capture.
        async def layer(ctx: ActorMailboxContext, msg: Any) -> Any:
            return await mw.on_receive(ctx, msg, inner)

        return layer

    chain = handler
    for mw in reversed(middlewares):
        chain = wrap(mw, chain)
    return chain
|
||||
216
backend/packages/harness/deerflow/actor/ref.py
Normal file
216
backend/packages/harness/deerflow/actor/ref.py
Normal file
@ -0,0 +1,216 @@
|
||||
"""ActorRef — immutable, serializable reference to an actor."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .system import _ActorCell
|
||||
|
||||
|
||||
class ActorRef:
    """Immutable handle for sending messages to an actor.

    Users never construct this directly — it is returned by
    ``ActorSystem.spawn`` or ``ActorContext.spawn``.
    """

    __slots__ = ("_cell",)

    def __init__(self, cell: _ActorCell) -> None:
        self._cell = cell

    @property
    def name(self) -> str:
        # Short name of the actor (delegates to the runtime cell).
        return self._cell.name

    @property
    def path(self) -> str:
        # Full hierarchical path of the actor within the system.
        return self._cell.path

    @property
    def is_alive(self) -> bool:
        return not self._cell.stopped

    async def tell(self, message: Any, *, sender: ActorRef | None = None) -> None:
        """Fire-and-forget message delivery.

        Messages to a stopped actor are routed to the system's
        dead-letter sink instead of raising.
        """
        if self._cell.stopped:
            self._cell.system._dead_letter(self, message, sender)
            return
        await self._cell.enqueue(_Envelope(message, sender))

    async def ask(self, message: Any, *, timeout: float = 5.0) -> Any:
        """Request-response with timeout.

        Uses correlation ID + ReplyRegistry instead of passing a Future
        through the mailbox. This makes ask work with any Mailbox backend
        (memory, Redis, RabbitMQ, etc.).

        Raises ``asyncio.TimeoutError`` if the actor doesn't reply in time.
        Raises the actor's exception if ``on_receive`` fails.
        """
        if self._cell.stopped:
            raise ActorStoppedError(f"Actor {self.path} is stopped")
        corr_id = uuid.uuid4().hex
        future = self._cell.system._replies.register(corr_id)
        try:
            # reply_to carries the caller's system ID so a ReplyChannel can
            # route the reply back even across processes.
            envelope = _Envelope(message, sender=None, correlation_id=corr_id, reply_to=self._cell.system.system_id)
            await self._cell.enqueue(envelope)
            return await asyncio.wait_for(future, timeout=timeout)
        finally:
            # Always drop the registry entry — on timeout/cancellation the
            # reply may never arrive, and a stale entry would leak the Future.
            self._cell.system._replies.discard(corr_id)

    def stop(self) -> None:
        """Request graceful shutdown."""
        self._cell.request_stop()

    def __repr__(self) -> str:
        alive = "alive" if self.is_alive else "dead"
        return f"ActorRef({self.path}, {alive})"

    def __eq__(self, other: object) -> bool:
        # Two refs are equal iff they wrap the same runtime cell.
        if isinstance(other, ActorRef):
            return self._cell is other._cell
        return NotImplemented

    def __hash__(self) -> int:
        # Consistent with __eq__: hash by cell identity.
        return id(self._cell)
|
||||
|
||||
|
||||
class ActorStoppedError(Exception):
    """Raised when sending to a stopped actor via ask.

    (``tell`` to a stopped actor dead-letters instead of raising.)
    """
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal message wrappers (serializable — no Future objects)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Envelope:
    """Message envelope flowing through mailboxes.

    All fields are serializable (no asyncio.Future). This is what
    enables ask() to work across MQ-backed mailboxes.
    """

    __slots__ = ("payload", "sender", "correlation_id", "reply_to")

    def __init__(
        self,
        payload: Any,
        sender: ActorRef | None = None,
        correlation_id: str | None = None,
        reply_to: str | None = None,
    ) -> None:
        self.payload = payload
        # Process-local ref; dropped when the envelope is serialized.
        self.sender = sender
        # Set only for ask(); ties the eventual reply back to a Future.
        self.correlation_id = correlation_id
        self.reply_to = reply_to  # System ID of the caller (for cross-process reply routing)
|
||||
|
||||
|
||||
class _Stop:
    """Sentinel placed on the mailbox to trigger graceful shutdown.

    Carries no state — the type itself is the signal.
    """
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReplyRegistry — maps correlation_id → Future (lives on ActorSystem)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ReplyRegistry:
|
||||
"""In-memory registry mapping correlation IDs to Futures.
|
||||
|
||||
Used by ask() to receive replies without putting Futures in the mailbox.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending: dict[str, asyncio.Future[Any]] = {}
|
||||
|
||||
def register(self, corr_id: str) -> asyncio.Future[Any]:
|
||||
"""Create and register a Future for a correlation ID."""
|
||||
future: asyncio.Future[Any] = asyncio.get_running_loop().create_future()
|
||||
self._pending[corr_id] = future
|
||||
return future
|
||||
|
||||
def resolve(self, corr_id: str, result: Any) -> None:
|
||||
"""Complete a pending ask with a result."""
|
||||
future = self._pending.pop(corr_id, None)
|
||||
if future is not None and not future.done():
|
||||
future.set_result(result)
|
||||
|
||||
def reject(self, corr_id: str, error: Exception) -> None:
|
||||
"""Complete a pending ask with an error."""
|
||||
future = self._pending.pop(corr_id, None)
|
||||
if future is not None and not future.done():
|
||||
future.set_exception(error)
|
||||
|
||||
def discard(self, corr_id: str) -> None:
|
||||
"""Remove a pending entry (e.g. on timeout)."""
|
||||
self._pending.pop(corr_id, None)
|
||||
|
||||
def reject_all(self, error: Exception) -> None:
|
||||
"""Reject all pending asks (e.g. on system shutdown)."""
|
||||
for future in self._pending.values():
|
||||
if not future.done():
|
||||
future.set_exception(error)
|
||||
self._pending.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ReplyChannel — abstraction for routing replies (local or cross-process)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ReplyMessage:
|
||||
"""Reply payload sent through ReplyChannel.
|
||||
|
||||
Carries the original exception object for local delivery (preserves type).
|
||||
For cross-process serialization, use ``to_dict``/``from_dict``.
|
||||
"""
|
||||
|
||||
__slots__ = ("correlation_id", "result", "error", "exception")
|
||||
|
||||
def __init__(self, correlation_id: str, result: Any = None, error: str | None = None, exception: Exception | None = None) -> None:
|
||||
self.correlation_id = correlation_id
|
||||
self.result = result
|
||||
self.error = error
|
||||
self.exception = exception # Original exception (local only, not serializable)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Serialize for cross-process transport (exception becomes string)."""
|
||||
return {"correlation_id": self.correlation_id, "result": self.result, "error": self.error}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any]) -> _ReplyMessage:
|
||||
return cls(d["correlation_id"], d.get("result"), d.get("error"))
|
||||
|
||||
|
||||
class ReplyChannel:
    """Routes replies from an actor back to the caller's ReplyRegistry.

    The default implementation resolves locally (same process).
    Override ``send_reply`` for cross-process routing (e.g. via Redis pub/sub).
    """

    async def send_reply(self, reply_to: str, reply: _ReplyMessage, local_registry: _ReplyRegistry) -> None:
        """Deliver a reply to the system identified by *reply_to*.

        Default: assumes reply_to is the local system → resolve directly.
        Override for MQ-backed cross-process delivery.
        """
        if reply.exception is not None:
            # Local delivery: hand back the original exception, type intact.
            local_registry.reject(reply.correlation_id, reply.exception)
            return
        if reply.error is not None:
            # Cross-process delivery: the exception arrived as a string.
            local_registry.reject(reply.correlation_id, RuntimeError(reply.error))
            return
        local_registry.resolve(reply.correlation_id, reply.result)

    async def start_listener(self, system_id: str, registry: _ReplyRegistry) -> None:
        """Start listening for inbound replies (no-op for local delivery)."""

    async def stop_listener(self) -> None:
        """Stop the reply listener (no-op for local delivery)."""
|
||||
75
backend/packages/harness/deerflow/actor/supervision.py
Normal file
75
backend/packages/harness/deerflow/actor/supervision.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""Supervision strategies — Erlang/Akka-inspired fault tolerance."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections import deque
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Directive(enum.Enum):
    """What a supervisor should do when a child fails."""

    resume = "resume"  # ignore the error, keep processing messages
    restart = "restart"  # discard state, create a fresh actor instance
    stop = "stop"  # terminate the child permanently
    escalate = "escalate"  # propagate the failure to the grandparent
|
||||
|
||||
|
||||
class SupervisorStrategy:
    """Base class for supervision strategies.

    Args:
        max_restarts: Maximum restarts allowed within *within_seconds*.
            Exceeding this limit stops the child permanently.
        within_seconds: Time window for restart counting.
        decider: Maps exception → Directive. Default: always restart.
    """

    def __init__(
        self,
        *,
        max_restarts: int = 3,
        within_seconds: float = 60.0,
        decider: Callable[[Exception], Directive] | None = None,
    ) -> None:
        self.max_restarts = max_restarts
        self.within_seconds = within_seconds
        # Fall back to "always restart" when no decider is supplied.
        if decider is None:
            decider = lambda _: Directive.restart
        self.decider = decider
        # Per-child restart timestamps, pruned to the sliding window.
        self._restart_timestamps: dict[str, deque[float]] = {}

    def decide(self, error: Exception) -> Directive:
        """Map a child failure to a Directive via the configured decider."""
        return self.decider(error)

    def record_restart(self, child_name: str) -> bool:
        """Record a restart; return True while the child is within limits."""
        now = time.monotonic()
        history = self._restart_timestamps.setdefault(child_name, deque())
        # Drop restarts that have aged out of the sliding window.
        horizon = now - self.within_seconds
        while history and history[0] < horizon:
            history.popleft()
        history.append(now)
        return len(history) <= self.max_restarts

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        """Return which children are affected by a failure of *failed_child*."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class OneForOneStrategy(SupervisorStrategy):
    """Only the failed child is affected."""

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        # The directive applies to the failing child alone.
        return [failed_child]
|
||||
|
||||
|
||||
class AllForOneStrategy(SupervisorStrategy):
    """All children are affected when any one fails."""

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        # Copy so callers can't mutate the supervisor's own child list.
        return list(all_children)
|
||||
381
backend/packages/harness/deerflow/actor/system.py
Normal file
381
backend/packages/harness/deerflow/actor/system.py
Normal file
@ -0,0 +1,381 @@
|
||||
"""ActorSystem — top-level actor container and lifecycle manager."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .actor import Actor, ActorContext
|
||||
from .mailbox import Empty, Mailbox, MemoryMailbox
|
||||
from .middleware import ActorMailboxContext, Middleware, NextFn, build_middleware_chain
|
||||
from .ref import ActorRef, ActorStoppedError, ReplyChannel, _Envelope, _ReplyMessage, _ReplyRegistry, _Stop
|
||||
from .supervision import Directive, SupervisorStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Timeout for middleware lifecycle hooks (on_started/on_stopped).
# NOTE(review): consumed outside this view (presumably in _ActorCell) — confirm.
_MIDDLEWARE_HOOK_TIMEOUT = 10.0

# Maximum dead letters kept in memory (ring-buffer maxlen in ActorSystem).
_MAX_DEAD_LETTERS = 10000

# Maximum consecutive failures before a root actor poison-quarantines a message.
# NOTE(review): consumed outside this view (presumably in _ActorCell) — confirm.
_MAX_CONSECUTIVE_FAILURES = 10
|
||||
|
||||
|
||||
@dataclass
class DeadLetter:
    """A message that could not be delivered."""

    # Ref the message was addressed to.
    recipient: ActorRef
    # The undeliverable payload.
    message: Any
    # Originating ref, if the sender identified itself.
    sender: ActorRef | None
|
||||
|
||||
|
||||
class ActorSystem:
    """Top-level actor container.

    Manages root actors, the ask-reply registry, the shared thread pool,
    and the dead letter sink.

    Args:
        name: Human-readable system name (also used as the thread-name prefix).
        max_dead_letters: Ring-buffer capacity for undeliverable messages.
        executor_workers: Size of the shared thread pool for blocking I/O;
            pass None or 0 to disable the pool.
        reply_channel: Custom reply routing (e.g. cross-process). Defaults
            to in-process delivery.
    """

    def __init__(
        self,
        name: str = "system",
        *,
        max_dead_letters: int = _MAX_DEAD_LETTERS,
        executor_workers: int | None = 4,
        reply_channel: ReplyChannel | None = None,
    ) -> None:
        import uuid as _uuid

        self.name = name
        # Unique per-process ID so cross-process replies can be routed back here.
        self.system_id = f"{name}-{_uuid.uuid4().hex[:8]}"
        self._root_cells: dict[str, _ActorCell] = {}
        self._dead_letters: deque[DeadLetter] = deque(maxlen=max_dead_letters)
        self._on_dead_letter: list[Any] = []
        self._shutting_down = False
        self._replies = _ReplyRegistry()
        self._reply_channel = reply_channel or ReplyChannel()
        # Shared thread pool for actors to run blocking I/O.
        from concurrent.futures import ThreadPoolExecutor

        self._executor = (
            ThreadPoolExecutor(max_workers=executor_workers, thread_name_prefix=f"actor-{name}")
            if executor_workers
            else None
        )

    async def spawn(
        self,
        actor_cls: type[Actor],
        name: str,
        *,
        mailbox_size: int = 256,
        mailbox: Mailbox | None = None,
        middlewares: list[Middleware] | None = None,
    ) -> ActorRef:
        """Spawn a root-level actor.

        Args:
            actor_cls: Actor subclass to instantiate.
            name: Unique name among root actors.
            mailbox_size: Capacity of the default MemoryMailbox.
            mailbox: Custom mailbox instance. If None, uses MemoryMailbox(mailbox_size).
            middlewares: Per-actor middleware chain (outermost first).

        Raises:
            ValueError: If a root actor named *name* already exists.
        """
        if name in self._root_cells:
            raise ValueError(f"Root actor '{name}' already exists")
        cell = _ActorCell(
            actor_cls=actor_cls,
            name=name,
            parent=None,
            system=self,
            mailbox=mailbox or MemoryMailbox(mailbox_size),
            middlewares=middlewares or [],
        )
        self._root_cells[name] = cell
        await cell.start()
        return cell.ref

    async def shutdown(self, *, timeout: float = 10.0) -> None:
        """Gracefully stop all actors, then fail any outstanding asks.

        Args:
            timeout: Seconds to wait for all root actor tasks to finish.
        """
        self._shutting_down = True
        tasks = []
        for cell in list(self._root_cells.values()):
            cell.request_stop()
            if cell.task is not None:
                tasks.append(cell.task)
        if tasks:
            # Stop all roots in parallel; stragglers are abandoned after timeout.
            await asyncio.wait(tasks, timeout=timeout)
        self._root_cells.clear()
        # Wake every caller still blocked in ask().
        self._replies.reject_all(ActorStoppedError("ActorSystem shutting down"))
        await self._reply_channel.stop_listener()
        if self._executor is not None:
            # wait=False: worker threads may be mid-blocking-call; don't hang shutdown.
            self._executor.shutdown(wait=False)
        logger.info("ActorSystem '%s' shut down (%d dead letters)", self.name, len(self._dead_letters))

    def _dead_letter(self, recipient: ActorRef, message: Any, sender: ActorRef | None) -> None:
        """Record an undeliverable message and notify registered listeners."""
        dl = DeadLetter(recipient=recipient, message=message, sender=sender)
        self._dead_letters.append(dl)
        for cb in self._on_dead_letter:
            try:
                cb(dl)
            except Exception:
                # A faulty listener must not break the send path, but
                # swallowing silently hid listener bugs — log it instead.
                logger.exception("Dead letter listener %r raised", cb)
        logger.debug("Dead letter: %s → %s", type(message).__name__, recipient.path)

    def on_dead_letter(self, callback: Any) -> None:
        """Register a dead letter listener (called synchronously per dead letter)."""
        self._on_dead_letter.append(callback)

    @property
    def dead_letters(self) -> list[DeadLetter]:
        """Snapshot copy of recorded dead letters (oldest first)."""
        return list(self._dead_letters)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _ActorCell — internal runtime wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _ActorCell:
    """Runtime container for a single actor instance.

    Manages the mailbox, processing loop, children, and supervision.
    Not part of the public API.
    """

    def __init__(
        self,
        actor_cls: type[Actor],
        name: str,
        parent: _ActorCell | None,
        system: ActorSystem,
        mailbox: Mailbox,
        middlewares: list[Middleware] | None = None,
    ) -> None:
        self.actor_cls = actor_cls
        self.name = name
        self.parent = parent
        self.system = system
        self.children: dict[str, _ActorCell] = {}
        self.mailbox = mailbox
        self.ref = ActorRef(self)
        self.actor: Actor | None = None
        self.task: asyncio.Task[None] | None = None
        self.stopped = False
        self._supervisor_strategy: SupervisorStrategy | None = None
        self._middlewares = middlewares or []
        self._receive_chain: NextFn | None = None
        # Cache path (immutable after init — parent never changes)
        parts: list[str] = []
        cell: _ActorCell | None = self
        while cell is not None:
            parts.append(cell.name)
            cell = cell.parent
        parts.append(system.name)
        self.path = "/" + "/".join(reversed(parts))

    async def start(self) -> None:
        """Instantiate the actor, build the middleware chain, and start the loop."""
        self.actor = self.actor_cls()
        self.actor.context = ActorContext(self)

        async def _inner_handler(_ctx: ActorMailboxContext, message: Any) -> Any:
            return await self.actor.on_receive(message)  # type: ignore[union-attr]

        if self._middlewares:
            self._receive_chain = build_middleware_chain(self._middlewares, _inner_handler)
        else:
            self._receive_chain = _inner_handler
        # Notify middleware of start (with timeout to prevent blocking)
        for mw in self._middlewares:
            try:
                await asyncio.wait_for(mw.on_started(self.ref), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
            except asyncio.TimeoutError:
                logger.warning("Middleware %s.on_started timed out for %s", type(mw).__name__, self.path)
        await self.actor.on_started()
        self.task = asyncio.create_task(self._run(), name=f"actor:{self.path}")

    async def enqueue(self, msg: _Envelope | _Stop) -> None:
        """Enqueue a message; on a full mailbox, reject asks / dead-letter tells."""
        if not self.mailbox.put_nowait(msg):
            if isinstance(msg, _Envelope) and msg.correlation_id is not None:
                self.system._replies.reject(msg.correlation_id, RuntimeError(f"Mailbox full: {self.path}"))
            elif isinstance(msg, _Envelope):
                self.system._dead_letter(self.ref, msg.payload, msg.sender)

    def request_stop(self) -> None:
        """Request graceful shutdown. Falls back to task.cancel() if mailbox full."""
        if not self.stopped:
            if not self.mailbox.put_nowait(_Stop()):
                if self.task is not None and not self.task.done():
                    self.task.cancel()
                else:
                    self.stopped = True

    async def spawn_child(
        self,
        actor_cls: type[Actor],
        name: str,
        *,
        mailbox_size: int = 256,
        mailbox: Mailbox | None = None,
        middlewares: list[Middleware] | None = None,
    ) -> ActorRef:
        """Spawn a child actor supervised by this cell.

        Raises:
            ValueError: If a child named ``name`` already exists.
        """
        if name in self.children:
            raise ValueError(f"Child '{name}' already exists under {self.path}")
        child = _ActorCell(
            actor_cls=actor_cls,
            name=name,
            parent=self,
            system=self.system,
            mailbox=mailbox or MemoryMailbox(mailbox_size),
            middlewares=middlewares or [],
        )
        self.children[name] = child
        await child.start()
        return child.ref

    # -- Processing loop -------------------------------------------------------

    async def _run(self) -> None:
        """Main mailbox-processing loop; exits on _Stop, cancel, or failure limit."""
        consecutive_failures = 0
        try:
            while not self.stopped:
                try:
                    msg = await self.mailbox.get()
                except asyncio.CancelledError:
                    break

                if isinstance(msg, _Stop):
                    break

                try:
                    if not isinstance(msg, _Envelope):
                        # Unknown mailbox payload (e.g. from a shared backend) — skip.
                        continue
                    msg_type = "ask" if msg.correlation_id else "tell"
                    ctx = ActorMailboxContext(self.ref, msg.sender, msg_type)
                    result = await self._receive_chain(ctx, msg.payload)  # type: ignore[misc]
                    if msg.correlation_id is not None:
                        reply = _ReplyMessage(msg.correlation_id, result=result)
                        await self.system._reply_channel.send_reply(msg.reply_to or self.system.system_id, reply, self.system._replies)
                    consecutive_failures = 0
                except Exception as exc:
                    # Propagate the failure to the asker first, then supervise.
                    if isinstance(msg, _Envelope) and msg.correlation_id is not None:
                        reply = _ReplyMessage(msg.correlation_id, error=str(exc), exception=exc)
                        await self.system._reply_channel.send_reply(msg.reply_to or self.system.system_id, reply, self.system._replies)
                    if self.parent is not None:
                        await self.parent._handle_child_failure(self, exc)
                    else:
                        # Root actors have no supervisor: tolerate a bounded
                        # number of consecutive failures, then stop.
                        consecutive_failures += 1
                        logger.error("Uncaught error in root actor %s (%d/%d): %s", self.path, consecutive_failures, _MAX_CONSECUTIVE_FAILURES, exc)
                        if consecutive_failures >= _MAX_CONSECUTIVE_FAILURES:
                            logger.error("Root actor %s hit consecutive failure limit — stopping", self.path)
                            break
        except asyncio.CancelledError:
            pass  # Fall through to _shutdown
        finally:
            await self._shutdown()

    async def _shutdown(self) -> None:
        """Stop children, drain the mailbox into dead letters, run stop hooks."""
        self.stopped = True
        # Parallel child shutdown prevents cascading timeouts.
        child_tasks = []
        for child in list(self.children.values()):
            child.request_stop()
            if child.task is not None:
                child_tasks.append(child.task)
        if child_tasks:
            _, pending = await asyncio.wait(child_tasks, timeout=10.0)
            for t in pending:
                t.cancel()
            # FIX: mark *every* leaked child as stopped. The previous code
            # compared `child.task is t` after the cancel loop, so only the
            # child owning the last-iterated pending task was marked.
            for child in self.children.values():
                if child.task in pending:
                    child.stopped = True
        # Drain mailbox → dead letters (use try/except to handle all backends)
        while True:
            try:
                msg = self.mailbox.get_nowait()
            except Empty:
                break
            if isinstance(msg, _Envelope):
                if msg.correlation_id is not None:
                    self.system._replies.reject(msg.correlation_id, ActorStoppedError(f"Actor {self.path} stopped"))
                else:
                    self.system._dead_letter(self.ref, msg.payload, msg.sender)
        # Lifecycle hook
        for mw in self._middlewares:
            try:
                await asyncio.wait_for(mw.on_stopped(self.ref), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
            except asyncio.TimeoutError:
                logger.warning("Middleware %s.on_stopped timed out for %s", type(mw).__name__, self.path)
            except Exception:
                logger.exception("Error in middleware on_stopped for %s", self.path)
        if self.actor is not None:
            try:
                await self.actor.on_stopped()
            except Exception:
                logger.exception("Error in on_stopped for %s", self.path)
        # Remove from parent
        if self.parent is not None:
            self.parent.children.pop(self.name, None)

    # -- Supervision -----------------------------------------------------------

    def _get_supervisor_strategy(self) -> SupervisorStrategy:
        """Lazily build and cache the actor's supervisor strategy."""
        if self._supervisor_strategy is None:
            self._supervisor_strategy = self.actor.supervisor_strategy()  # type: ignore[union-attr]
        return self._supervisor_strategy

    async def _handle_child_failure(self, child: _ActorCell, error: Exception) -> None:
        """Apply this actor's supervision directive to a failed child."""
        strategy = self._get_supervisor_strategy()
        directive = strategy.decide(error)

        # OneForOne → just the failed child; AllForOne → all siblings too.
        affected = strategy.apply_to_children(child.name, list(self.children.keys()))

        if directive == Directive.resume:
            logger.info("Supervisor %s: resume %s after %s", self.path, child.path, type(error).__name__)
            return

        if directive == Directive.stop:
            for name in affected:
                c = self.children.get(name)
                if c is not None:
                    c.request_stop()
            logger.info("Supervisor %s: stop %s after %s", self.path, [self.children[n].path for n in affected if n in self.children], type(error).__name__)
            return

        if directive == Directive.escalate:
            # Re-raise inside our own _run loop so *our* parent supervises us.
            logger.info("Supervisor %s: escalate %s", self.path, type(error).__name__)
            raise error

        if directive == Directive.restart:
            for name in affected:
                c = self.children.get(name)
                if c is None:
                    continue
                if not strategy.record_restart(name):
                    logger.warning("Supervisor %s: child %s exceeded restart limit — stopping", self.path, c.path)
                    c.request_stop()
                    continue
                await self._restart_child(c, error)

    async def _restart_child(self, child: _ActorCell, error: Exception) -> None:
        """Replace a child's actor instance in place, preserving cell and mailbox."""
        logger.info("Supervisor %s: restarting %s after %s", self.path, child.path, type(error).__name__)
        # Stop the old actor (but keep the cell and mailbox)
        old_actor = child.actor
        if old_actor is not None:
            try:
                await old_actor.on_stopped()
            except Exception:
                logger.exception("Error in on_stopped during restart of %s", child.path)

        # Notify middleware of restart (reset per-instance state)
        for mw in child._middlewares:
            try:
                await asyncio.wait_for(mw.on_restart(child.ref, error), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
            except asyncio.TimeoutError:
                logger.warning("Middleware %s.on_restart timed out for %s", type(mw).__name__, child.path)
            except Exception:
                logger.exception("Error in middleware on_restart for %s", child.path)
        # Create fresh instance
        new_actor = child.actor_cls()
        new_actor.context = ActorContext(child)
        child.actor = new_actor
        try:
            await new_actor.on_restart(error)
            await new_actor.on_started()
        except Exception:
            # A child that cannot re-initialize is stopped rather than looped.
            logger.exception("Error during restart initialization of %s", child.path)
            child.request_stop()
|
||||
263
backend/tests/bench_actor.py
Normal file
263
backend/tests/bench_actor.py
Normal file
@ -0,0 +1,263 @@
|
||||
"""Actor framework benchmarks — throughput, latency, concurrency."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import statistics
|
||||
|
||||
from deerflow.actor import Actor, ActorSystem, Middleware
|
||||
|
||||
|
||||
class NoopActor(Actor):
    # Minimal echo actor: baseline for throughput/latency benchmarks.
    async def on_receive(self, message):
        return message
|
||||
|
||||
|
||||
class CounterActor(Actor):
    # Counts every received message; used to verify delivery in tell benchmarks.
    async def on_started(self):
        # Counter is (re)initialized on each start, so a restart resets it.
        self.count = 0

    async def on_receive(self, message):
        self.count += 1
        return self.count
|
||||
|
||||
|
||||
class ChainActor(Actor):
    """Forwards message to next actor in chain."""

    # Set externally after spawn (see bench_actor_chain); the last actor in the
    # chain leaves it as None and echoes the message back.
    next_ref = None

    async def on_receive(self, message):
        if self.next_ref is not None:
            # Recursive ask: the reply unwinds back through every hop.
            return await self.next_ref.ask(message)
        return message
|
||||
|
||||
|
||||
class ComputeActor(Actor):
    """Simulates CPU work via thread pool."""

    async def on_receive(self, message):
        # Iterative Fibonacci — pure CPU work. Offloaded to the system's
        # shared thread pool so the event loop stays responsive.
        def _fib(count):
            prev, cur = 0, 1
            for _ in range(count):
                prev, cur = cur, prev + cur
            return prev

        return await self.context.run_in_executor(_fib, message)
|
||||
|
||||
|
||||
class CountMiddleware(Middleware):
    # Trivial pass-through middleware: increments a counter per message.
    # Used to measure the overhead of the middleware pipeline itself.
    def __init__(self):
        self.count = 0

    async def on_receive(self, ctx, message, next_fn):
        self.count += 1
        return await next_fn(ctx, message)
|
||||
|
||||
|
||||
def fmt(n):
    """Render a count compactly: millions as '1.5M', thousands as '12K'."""
    thousand, million = 1_000, 1_000_000
    if n < thousand:
        return str(n)
    if n < million:
        return f"{n/thousand:.0f}K"
    return f"{n/million:.1f}M"
|
||||
|
||||
|
||||
async def bench_tell_throughput(n=100_000):
    """Measure tell (fire-and-forget) throughput."""
    system = ActorSystem("bench")
    # Mailbox sized above n so no tell is ever dropped to dead letters.
    ref = await system.spawn(CounterActor, "counter", mailbox_size=n + 10)

    start = time.perf_counter()
    for _ in range(n):
        await ref.tell("inc")
    # Wait for all messages to be processed
    # (the ask acts as a barrier: it is answered only after every queued tell).
    count = await ref.ask("get", timeout=30.0)
    elapsed = time.perf_counter() - start

    await system.shutdown()
    rate = n / elapsed
    print(f" tell throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(rate))}/s")
|
||||
|
||||
|
||||
async def bench_ask_throughput(n=50_000):
    """Measure ask (request-response) throughput."""
    system = ActorSystem("bench")
    ref = await system.spawn(NoopActor, "echo")

    # Sequential asks: each round-trip completes before the next starts,
    # so this measures full request/reply cost, not pipelining.
    start = time.perf_counter()
    for _ in range(n):
        await ref.ask("ping")
    elapsed = time.perf_counter() - start

    await system.shutdown()
    rate = n / elapsed
    print(f" ask throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(rate))}/s")
|
||||
|
||||
|
||||
async def bench_ask_latency(n=10_000):
    """Measure ask round-trip latency percentiles."""
    system = ActorSystem("bench")
    ref = await system.spawn(NoopActor, "echo")

    # Warmup
    for _ in range(100):
        await ref.ask("warmup")

    latencies = []
    for _ in range(n):
        t0 = time.perf_counter()
        await ref.ask("ping")
        latencies.append((time.perf_counter() - t0) * 1_000_000)  # microseconds

    await system.shutdown()
    # Percentiles via sorted-index lookup (nearest-rank approximation).
    latencies.sort()
    p50 = latencies[len(latencies) // 2]
    p99 = latencies[int(len(latencies) * 0.99)]
    p999 = latencies[int(len(latencies) * 0.999)]
    print(f" ask latency: p50={p50:.0f}µs p99={p99:.0f}µs p99.9={p999:.0f}µs")
|
||||
|
||||
|
||||
async def bench_concurrent_actors(num_actors=1000, msgs_per_actor=100):
    """Measure throughput with many concurrent actors."""
    system = ActorSystem("bench")
    refs = []
    for i in range(num_actors):
        refs.append(await system.spawn(CounterActor, f"a{i}", mailbox_size=msgs_per_actor + 10))

    start = time.perf_counter()

    async def send_batch(ref, n):
        # Pump n tells into one actor, then ask for its final count.
        for i in range(n):
            await ref.tell("inc")
            # Yield control every 50 msgs so actor loops can drain
            if i % 50 == 49:
                await asyncio.sleep(0)
        return await ref.ask("get", timeout=30.0)

    results = await asyncio.gather(*[send_batch(r, msgs_per_actor) for r in refs])
    elapsed = time.perf_counter() - start

    total = num_actors * msgs_per_actor
    # Sum of final counters vs. messages sent — nonzero loss means drops.
    delivered = sum(results)
    rate = total / elapsed
    loss = total - delivered
    print(f" {num_actors} actors × {msgs_per_actor} msgs: {fmt(total)} in {elapsed:.2f}s = {fmt(int(rate))}/s (loss: {loss})")

    await system.shutdown()
|
||||
|
||||
|
||||
async def bench_actor_chain(depth=100):
    """Measure ask latency through a chain of actors (hop overhead)."""
    system = ActorSystem("bench")
    refs = []
    for i in range(depth):
        refs.append(await system.spawn(ChainActor, f"c{i}"))
    # Link chain: c0 → c1 → ... → c99
    # NOTE: reaches through ActorRef._cell internals — benchmark-only shortcut.
    for i in range(depth - 1):
        refs[i]._cell.actor.next_ref = refs[i + 1]

    start = time.perf_counter()
    # One ask at the head traverses every hop and unwinds the full reply chain.
    result = await refs[0].ask("ping", timeout=30.0)
    elapsed = time.perf_counter() - start

    assert result == "ping"
    per_hop = elapsed / depth * 1_000_000  # µs
    print(f" chain {depth} hops: {elapsed*1000:.1f}ms total, {per_hop:.0f}µs/hop")

    await system.shutdown()
|
||||
|
||||
|
||||
async def bench_middleware_overhead(n=50_000):
    """Measure overhead of middleware pipeline."""
    mw = CountMiddleware()

    # Two separate systems so the plain and middleware runs do not share state.
    system_plain = ActorSystem("plain")
    ref_plain = await system_plain.spawn(NoopActor, "echo")

    system_mw = ActorSystem("mw")
    ref_mw = await system_mw.spawn(NoopActor, "echo", middlewares=[mw])

    # Plain
    t0 = time.perf_counter()
    for _ in range(n):
        await ref_plain.ask("p")
    plain_elapsed = time.perf_counter() - t0

    # With middleware
    t0 = time.perf_counter()
    for _ in range(n):
        await ref_mw.ask("p")
    mw_elapsed = time.perf_counter() - t0

    # Relative slowdown of the middleware run vs. the plain baseline.
    overhead = ((mw_elapsed - plain_elapsed) / plain_elapsed) * 100
    print(f" middleware overhead: {overhead:+.1f}% ({fmt(n)} ask calls, 1 middleware)")

    await system_plain.shutdown()
    await system_mw.shutdown()
|
||||
|
||||
|
||||
async def bench_executor_parallel(num_tasks=16):
    """Measure thread pool parallelism with CPU work."""
    # 8 workers for 16 tasks: expect roughly two waves of parallel execution.
    system = ActorSystem("bench", executor_workers=8)
    refs = [await system.spawn(ComputeActor, f"cpu{i}") for i in range(num_tasks)]

    start = time.perf_counter()
    results = await asyncio.gather(*[r.ask(10_000, timeout=30.0) for r in refs])
    elapsed = time.perf_counter() - start

    # FIX: `results` was gathered but never checked — verify every task
    # produced a value so a silently failing pool is caught by the benchmark.
    assert len(results) == num_tasks

    print(f" executor parallel: {num_tasks} fib(10K) in {elapsed*1000:.0f}ms ({num_tasks/elapsed:.0f} tasks/s)")

    await system.shutdown()
|
||||
|
||||
|
||||
async def bench_spawn_teardown(n=5000):
    """Measure actor spawn + shutdown speed."""
    system = ActorSystem("bench")

    start = time.perf_counter()
    refs = []
    for i in range(n):
        refs.append(await system.spawn(NoopActor, f"a{i}"))
    spawn_elapsed = time.perf_counter() - start

    # Shutdown stops all n root actors; measured separately from spawn.
    start = time.perf_counter()
    await system.shutdown()
    shutdown_elapsed = time.perf_counter() - start

    print(f" spawn {n}: {spawn_elapsed*1000:.0f}ms ({n/spawn_elapsed:.0f}/s)")
    print(f" shutdown {n}: {shutdown_elapsed*1000:.0f}ms")
|
||||
|
||||
|
||||
async def main():
    """Run every benchmark group in sequence and print a report."""
    print("=" * 60)
    print(" Actor Framework Benchmarks")
    print("=" * 60)
    print()

    print("[Throughput]")
    await bench_tell_throughput()
    await bench_ask_throughput()
    print()

    print("[Latency]")
    await bench_ask_latency()
    await bench_actor_chain()
    print()

    print("[Concurrency]")
    await bench_concurrent_actors()
    await bench_executor_parallel()
    print()

    print("[Overhead]")
    await bench_middleware_overhead()
    print()

    print("[Lifecycle]")
    await bench_spawn_teardown()
    print()

    print("=" * 60)
    print(" Done")
    print("=" * 60)
|
||||
|
||||
|
||||
# Run the full benchmark suite when executed directly.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
442
backend/tests/test_actor.py
Normal file
442
backend/tests/test_actor.py
Normal file
@ -0,0 +1,442 @@
|
||||
"""Tests for the async Actor framework."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.actor import (
|
||||
Actor,
|
||||
ActorRef,
|
||||
ActorSystem,
|
||||
AllForOneStrategy,
|
||||
Directive,
|
||||
Middleware,
|
||||
OneForOneStrategy,
|
||||
)
|
||||
from deerflow.actor.ref import ActorStoppedError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Basic actors for testing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EchoActor(Actor):
    # Simplest possible actor: replies with the message it received.
    async def on_receive(self, message):
        return message
|
||||
|
||||
|
||||
class CounterActor(Actor):
    # Stateful actor: "inc" bumps the counter (no reply), "get" returns it.
    async def on_started(self):
        # Initialized in on_started so a supervisor restart resets the count.
        self.count = 0

    async def on_receive(self, message):
        if message == "inc":
            self.count += 1
        elif message == "get":
            return self.count
|
||||
|
||||
|
||||
class CrashActor(Actor):
    # Raises ValueError on "crash" to exercise supervision; otherwise replies "ok".
    async def on_receive(self, message):
        if message == "crash":
            raise ValueError("boom")
        return "ok"
|
||||
|
||||
|
||||
class ParentActor(Actor):
    # Supervisor that restarts its crashing child (OneForOne, up to 3 restarts).
    def __init__(self):
        self.child_ref: ActorRef | None = None
        self.restarts = 0

    def supervisor_strategy(self):
        return OneForOneStrategy(max_restarts=3, within_seconds=60)

    async def on_started(self):
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        if message == "get_child":
            # Expose the child's ref so tests can message it directly.
            return self.child_ref
|
||||
|
||||
|
||||
class StopOnCrashParent(Actor):
    # Supervisor whose decider always stops a failed child (no restart).
    def supervisor_strategy(self):
        return OneForOneStrategy(decider=lambda _: Directive.stop)

    async def on_started(self):
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        if message == "get_child":
            return self.child_ref
|
||||
|
||||
|
||||
class AllForOneParent(Actor):
    # AllForOne supervisor: a crash in either child restarts both siblings.
    def supervisor_strategy(self):
        return AllForOneStrategy(max_restarts=2, within_seconds=60)

    async def on_started(self):
        self.c1 = await self.context.spawn(CounterActor, "c1")
        self.c2 = await self.context.spawn(CrashActor, "c2")

    async def on_receive(self, message):
        if message == "get_children":
            return (self.c1, self.c2)
|
||||
|
||||
|
||||
class LifecycleActor(Actor):
    # Records lifecycle events in CLASS-level flags so tests can observe them
    # even after the instance is replaced by a restart. Tests must reset the
    # flags before use — they are shared across all instances.
    started = False
    stopped = False
    restarted_with: Exception | None = None

    async def on_started(self):
        LifecycleActor.started = True

    async def on_stopped(self):
        LifecycleActor.stopped = True

    async def on_restart(self, error):
        LifecycleActor.restarted_with = error

    async def on_receive(self, message):
        if message == "crash":
            raise RuntimeError("lifecycle crash")
        return "alive"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBasicMessaging:
    """tell/ask semantics: round-trips, timeouts, and delivery to stopped actors."""

    @pytest.mark.anyio
    async def test_tell_and_ask(self):
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo")
        result = await ref.ask("hello")
        assert result == "hello"
        await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_timeout(self):
        # Actor never replies within the deadline → ask must raise TimeoutError.
        class SlowActor(Actor):
            async def on_receive(self, message):
                await asyncio.sleep(10)

        system = ActorSystem("test")
        ref = await system.spawn(SlowActor, "slow")
        with pytest.raises(asyncio.TimeoutError):
            await ref.ask("hi", timeout=0.1)
        await system.shutdown()

    @pytest.mark.anyio
    async def test_tell_fire_and_forget(self):
        system = ActorSystem("test")
        ref = await system.spawn(CounterActor, "counter")
        await ref.tell("inc")
        await ref.tell("inc")
        await ref.tell("inc")
        # Give the actor time to process
        await asyncio.sleep(0.05)
        count = await ref.ask("get")
        assert count == 3
        await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_stopped_actor(self):
        # ask() to a stopped actor fails fast instead of hanging.
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo")
        ref.stop()
        await asyncio.sleep(0.05)
        with pytest.raises(ActorStoppedError):
            await ref.ask("hello")
        await system.shutdown()

    @pytest.mark.anyio
    async def test_tell_stopped_actor_goes_to_dead_letters(self):
        # tell() to a stopped actor is routed to the system's dead-letter sink.
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo")
        ref.stop()
        await asyncio.sleep(0.05)
        await ref.tell("orphan")
        assert len(system.dead_letters) >= 1
        await system.shutdown()
|
||||
|
||||
|
||||
class TestActorPath:
    """Actor paths are '/<system>/<root>[/<child>...]'."""

    @pytest.mark.anyio
    async def test_root_actor_path(self):
        system = ActorSystem("app")
        ref = await system.spawn(EchoActor, "echo")
        assert ref.path == "/app/echo"
        await system.shutdown()

    @pytest.mark.anyio
    async def test_child_actor_path(self):
        system = ActorSystem("app")
        parent = await system.spawn(ParentActor, "parent")
        child: ActorRef = await parent.ask("get_child")
        assert child.path == "/app/parent/child"
        await system.shutdown()
|
||||
|
||||
|
||||
class TestLifecycle:
    """on_started/on_stopped hooks and whole-system shutdown."""

    @pytest.mark.anyio
    async def test_on_started_called(self):
        LifecycleActor.started = False  # reset shared class-level flag
        system = ActorSystem("test")
        await system.spawn(LifecycleActor, "lc")
        assert LifecycleActor.started is True
        await system.shutdown()

    @pytest.mark.anyio
    async def test_on_stopped_called(self):
        LifecycleActor.stopped = False  # reset shared class-level flag
        system = ActorSystem("test")
        ref = await system.spawn(LifecycleActor, "lc")
        ref.stop()
        await asyncio.sleep(0.1)
        assert LifecycleActor.stopped is True
        await system.shutdown()

    @pytest.mark.anyio
    async def test_shutdown_stops_all(self):
        system = ActorSystem("test")
        r1 = await system.spawn(EchoActor, "a")
        r2 = await system.spawn(EchoActor, "b")
        await system.shutdown()
        assert not r1.is_alive
        assert not r2.is_alive
|
||||
|
||||
|
||||
class TestSupervision:
    """Supervisor directives: restart, stop, restart limits, AllForOne."""

    @pytest.mark.anyio
    async def test_restart_on_crash(self):
        """A crashed child under OneForOne restart stays reachable afterwards."""
        system = ActorSystem("test")
        parent = await system.spawn(ParentActor, "parent")
        child: ActorRef = await parent.ask("get_child")

        # Crash the child
        with pytest.raises(ValueError, match="boom"):
            await child.ask("crash")
        await asyncio.sleep(0.1)

        # Child should still be alive (restarted)
        assert child.is_alive
        result = await child.ask("safe")
        assert result == "ok"
        await system.shutdown()

    @pytest.mark.anyio
    async def test_stop_directive(self):
        """A decider returning Directive.stop terminates the failed child."""
        system = ActorSystem("test")
        parent = await system.spawn(StopOnCrashParent, "parent")
        child: ActorRef = await parent.ask("get_child")

        with pytest.raises(ValueError, match="boom"):
            await child.ask("crash")
        await asyncio.sleep(0.1)

        assert not child.is_alive
        await system.shutdown()

    @pytest.mark.anyio
    async def test_restart_limit_exceeded(self):
        """Exceeding max_restarts within the window stops the child instead."""
        system = ActorSystem("test")

        class StrictParent(Actor):
            def supervisor_strategy(self):
                return OneForOneStrategy(max_restarts=2, within_seconds=60)

            async def on_started(self):
                self.child_ref = await self.context.spawn(CrashActor, "child")

            async def on_receive(self, message):
                return self.child_ref

        parent = await system.spawn(StrictParent, "parent")
        child: ActorRef = await parent.ask("any")

        # Exhaust restart limit
        for _ in range(3):
            try:
                await child.ask("crash")
            except (ValueError, ActorStoppedError):
                # Later asks may fail with ActorStoppedError once the child dies.
                pass
            await asyncio.sleep(0.05)

        # After exceeding limit, child should be stopped
        assert not child.is_alive
        await system.shutdown()

    @pytest.mark.anyio
    async def test_all_for_one_restarts_siblings(self):
        """AllForOne: a crash in one child restarts its siblings too."""
        system = ActorSystem("test")
        parent = await system.spawn(AllForOneParent, "parent")
        c1, c2 = await parent.ask("get_children")

        # Increment counter on c1
        await c1.tell("inc")
        await asyncio.sleep(0.05)
        count_before = await c1.ask("get")
        assert count_before == 1

        # Crash c2 → AllForOne should restart both
        try:
            await c2.ask("crash")
        except ValueError:
            pass
        await asyncio.sleep(0.1)

        # c1 was restarted, counter should be 0
        count_after = await c1.ask("get")
        assert count_after == 0
        await system.shutdown()
|
||||
|
||||
|
||||
class TestDeadLetters:
    """Dead-letter listener callbacks fire for undeliverable messages."""

    @pytest.mark.anyio
    async def test_dead_letter_callback(self):
        received = []
        system = ActorSystem("test")
        system.on_dead_letter(lambda dl: received.append(dl))

        ref = await system.spawn(EchoActor, "echo")
        ref.stop()
        await asyncio.sleep(0.05)
        await ref.tell("orphan")

        assert len(received) >= 1
        assert received[-1].message == "orphan"
        await system.shutdown()
|
||||
|
||||
|
||||
class TestDuplicateNames:
    """Root actor names must be unique within a system."""

    @pytest.mark.anyio
    async def test_duplicate_root_name_raises(self):
        system = ActorSystem("test")
        await system.spawn(EchoActor, "echo")
        with pytest.raises(ValueError, match="already exists"):
            await system.spawn(EchoActor, "echo")
        await system.shutdown()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class LogMiddleware(Middleware):
    # Records message and lifecycle events in order, so tests can assert on
    # both interception ("before:/after:") and hook invocation ("started"/"stopped").
    def __init__(self):
        self.log: list[str] = []

    async def on_receive(self, ctx, message, next_fn):
        self.log.append(f"before:{message}")
        result = await next_fn(ctx, message)
        self.log.append(f"after:{result}")
        return result

    async def on_started(self, actor_ref):
        self.log.append("started")

    async def on_stopped(self, actor_ref):
        self.log.append("stopped")
|
||||
|
||||
|
||||
class TransformMiddleware(Middleware):
    """Uppercases string messages before passing to actor."""

    async def on_receive(self, ctx, message, next_fn):
        # Non-string messages pass through untouched.
        if isinstance(message, str):
            message = message.upper()
        return await next_fn(ctx, message)
|
||||
|
||||
|
||||
class TestExecutor:
    """The system's shared thread pool runs blocking work off the event loop."""

    @pytest.mark.anyio
    async def test_run_in_executor(self):
        """Blocking function runs in thread pool without blocking event loop."""
        import time

        class BlockingActor(Actor):
            async def on_receive(self, message):
                # Simulate blocking I/O via thread pool
                result = await self.context.run_in_executor(time.sleep, 0.01)
                return "done"

        system = ActorSystem("test", executor_workers=2)
        ref = await system.spawn(BlockingActor, "blocker")
        result = await ref.ask("go", timeout=5.0)
        assert result == "done"
        await system.shutdown()

    @pytest.mark.anyio
    async def test_concurrent_blocking_calls(self):
        """Multiple actors can run blocking I/O concurrently via shared pool."""
        import time

        class SlowActor(Actor):
            async def on_receive(self, message):
                await self.context.run_in_executor(time.sleep, 0.1)
                return "ok"

        system = ActorSystem("test", executor_workers=4)
        refs = [await system.spawn(SlowActor, f"s{i}") for i in range(4)]

        start = time.monotonic()
        results = await asyncio.gather(*[r.ask("go", timeout=5.0) for r in refs])
        elapsed = time.monotonic() - start

        assert all(r == "ok" for r in results)
        # 4 parallel × 0.1s should finish in ~0.1-0.2s, not 0.4s
        assert elapsed < 0.3
        await system.shutdown()


class TestMiddleware:
    """Tests for the middleware pipeline: message interception, lifecycle
    hooks, chain ordering, and fire-and-forget delivery."""

    @pytest.mark.anyio
    async def test_middleware_intercepts_messages(self):
        """Middleware observes the message before and the result after."""
        mw = LogMiddleware()
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
        result = await ref.ask("hello")
        assert result == "hello"
        assert "before:hello" in mw.log
        assert "after:hello" in mw.log
        await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_lifecycle_hooks(self):
        """on_started/on_stopped hooks fire around the actor's lifecycle."""
        mw = LogMiddleware()
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
        assert "started" in mw.log
        ref.stop()
        await asyncio.sleep(0.1)
        assert "stopped" in mw.log
        await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_chain_order(self):
        """First middleware wraps outermost — sees original message."""
        mw1 = LogMiddleware()
        mw2 = TransformMiddleware()
        system = ActorSystem("test")
        # Chain: mw1(mw2(actor)). mw1 logs original, mw2 uppercases, actor echoes
        ref = await system.spawn(EchoActor, "echo", middlewares=[mw1, mw2])
        result = await ref.ask("hello")
        assert result == "HELLO"  # TransformMiddleware uppercased
        assert "before:hello" in mw1.log  # LogMiddleware saw original
        assert "after:HELLO" in mw1.log  # LogMiddleware saw transformed result
        await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_with_tell(self):
        """Fire-and-forget tell() also flows through the middleware chain.

        The original test never sent a tell and only asserted the log was
        empty, so it exercised nothing.  It now tells a message and verifies
        the middleware observed it.
        """
        mw = LogMiddleware()
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
        # tell goes through middleware too
        ref.tell("hi")
        await asyncio.sleep(0.1)  # let the mailbox drain
        assert any(entry == "before:hi" for entry in mw.log)
        await system.shutdown()