feat: asyncio-native Actor framework with supervision, middleware, and pluggable mailbox

Lightweight actor library built on asyncio primitives (~800 lines):

- Actor base class with lifecycle hooks (on_started/on_stopped/on_restart)
- ActorRef with tell (fire-and-forget) and ask (request-response)
- Supervision: OneForOne/AllForOne strategies with restart limits
- Middleware pipeline for cross-cutting concerns
- Pluggable Mailbox interface (MemoryMailbox default, RedisMailbox optional)
- ReplyRegistry + ReplyChannel: ask() works across any mailbox backend
- System-level thread pool for blocking I/O (run_in_executor)
- Dead letter handling, poison message quarantine, parallel shutdown
- 22 tests + benchmark suite
This commit is contained in:
greatmengqi 2026-03-30 23:35:28 +08:00
parent 9e3d484858
commit 3e17417122
11 changed files with 1851 additions and 0 deletions

2
.gitignore vendored
View File

@ -2,6 +2,8 @@
docker/.cache/
# oh-my-claudecode state
.omc/
# Collaborator plugin state
.collaborator/
# OS generated files
.DS_Store
*.local

View File

@ -0,0 +1,40 @@
"""Async Actor framework — lightweight, asyncio-native, supervision-ready.

Usage::

    from deerflow.actor import Actor, ActorSystem

    class Greeter(Actor):
        async def on_receive(self, message):
            return f"Hello, {message}!"

    async def main():
        system = ActorSystem("app")
        ref = await system.spawn(Greeter, "greeter")
        reply = await ref.ask("World", timeout=5.0)
        print(reply)  # Hello, World!
        await system.shutdown()
"""
from .actor import Actor, ActorContext
from .mailbox import Mailbox, MemoryMailbox
from .middleware import Middleware
from .ref import ActorRef, ReplyChannel
from .supervision import AllForOneStrategy, Directive, OneForOneStrategy, SupervisorStrategy
from .system import ActorSystem, DeadLetter
# Public API of the package: exactly the names re-exported by the
# ``from .<module> import ...`` lines above, kept in alphabetical order.
__all__ = [
"Actor",
"ActorContext",
"ActorRef",
"ActorSystem",
"AllForOneStrategy",
"DeadLetter",
"Directive",
"Mailbox",
"MemoryMailbox",
"Middleware",
"OneForOneStrategy",
"ReplyChannel",
"SupervisorStrategy",
]

View File

@ -0,0 +1,109 @@
"""Actor base class and per-actor context."""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from .supervision import OneForOneStrategy, SupervisorStrategy
if TYPE_CHECKING:
from .ref import ActorRef
# Message type variable — use Actor[MyMsg] for typed actors
M = TypeVar("M")
R = TypeVar("R")
class ActorContext:
    """Per-actor runtime context, injected before ``on_started``.

    A thin facade over the actor's internal cell. It exposes the actor's
    own reference, its parent and children, the owning system, and lets
    the actor spawn supervised children or off-load blocking calls.
    """

    __slots__ = ("_cell",)

    def __init__(self, cell: Any) -> None:
        # ``cell`` is the internal runtime container that owns this actor.
        self._cell = cell

    @property
    def self_ref(self) -> ActorRef:
        """Reference to the actor that owns this context."""
        return self._cell.ref

    @property
    def parent(self) -> ActorRef | None:
        """Reference to the supervising parent, or ``None`` for a root actor."""
        parent_cell = self._cell.parent
        return None if parent_cell is None else parent_cell.ref

    @property
    def children(self) -> dict[str, ActorRef]:
        """Snapshot mapping of child name to child reference."""
        return {name: child.ref for name, child in self._cell.children.items()}

    @property
    def system(self) -> Any:
        """The system the owning cell belongs to."""
        return self._cell.system

    async def spawn(
        self,
        actor_cls: type[Actor],
        name: str,
        *,
        mailbox_size: int = 256,
        mailbox: Any | None = None,
        middlewares: list | None = None,
    ) -> ActorRef:
        """Spawn a child actor supervised by this actor.

        Args:
            actor_cls: Actor class to instantiate.
            name: Child name; must be unique among this actor's children.
            mailbox_size: Capacity of the default in-memory mailbox.
            mailbox: Optional custom mailbox instance. When ``None`` the
                cell falls back to its default mailbox of ``mailbox_size``.
            middlewares: Optional middleware list for the child.

        Returns:
            The new child's ``ActorRef``.
        """
        return await self._cell.spawn_child(
            actor_cls,
            name,
            mailbox_size=mailbox_size,
            mailbox=mailbox,
            middlewares=middlewares,
        )

    async def run_in_executor(self, fn: Callable[..., Any], *args: Any) -> Any:
        """Run a blocking function in the system's thread pool.

        Usage::

            result = await self.context.run_in_executor(requests.get, url)
        """
        import asyncio

        # The executor is read from the system's ``_executor`` attribute and
        # handed to the loop unchanged.
        executor = self._cell.system._executor
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(executor, fn, *args)
class Actor(Generic[M]):
    """Base class every actor derives from.

    The type parameter ``M`` names the message type the actor accepts::

        class Greeter(Actor[str]):
            async def on_receive(self, message: str) -> str:
                return f"Hello, {message}!"

        class Calculator(Actor[int | tuple[str, int, int]]):
            async def on_receive(self, message: int | tuple[str, int, int]) -> int:
                ...

    A plain, unparameterized ``Actor`` accepts ``Any`` (backward-compatible).
    """

    # Runtime context; assigned by the owning cell when the actor is started.
    context: ActorContext

    async def on_receive(self, message: M) -> Any:
        """Process one incoming message.

        Whatever is returned becomes the reply for ``ask`` callers; for
        ``tell`` the return value is thrown away. The base implementation
        does nothing and returns ``None``.
        """
        return None

    async def on_started(self) -> None:
        """Hook run after construction and before any message is received."""
        return None

    async def on_stopped(self) -> None:
        """Hook run on graceful shutdown; release resources here."""
        return None

    async def on_restart(self, error: Exception) -> None:
        """Hook run on the *new* instance before it resumes after a crash."""
        return None

    def supervisor_strategy(self) -> SupervisorStrategy:
        """Return the strategy used to supervise this actor's children.

        Override to customize. The default is a ``OneForOneStrategy``
        built with its default arguments.
        """
        default_strategy = OneForOneStrategy()
        return default_strategy

View File

@ -0,0 +1,94 @@
"""Pluggable mailbox abstraction — Akka-inspired enqueue/dequeue interface.
Built-in implementations:
- ``MemoryMailbox``: asyncio.Queue backed (default)
- Extend ``Mailbox`` for Redis, RabbitMQ, Kafka, etc.
"""
from __future__ import annotations
import abc
import asyncio
from typing import Any
class Mailbox(abc.ABC):
    """Abstract message queue feeding a single actor.

    Implementations must be async-safe for single-consumer usage, while
    several producers may call ``put`` concurrently.
    """

    @abc.abstractmethod
    async def put(self, msg: Any) -> bool:
        """Enqueue *msg*; ``True`` means accepted, ``False`` means dropped."""

    @abc.abstractmethod
    def put_nowait(self, msg: Any) -> bool:
        """Enqueue *msg* without blocking; ``True`` if accepted, else ``False``."""

    @abc.abstractmethod
    async def get(self) -> Any:
        """Wait for the next message and return it."""

    @abc.abstractmethod
    def get_nowait(self) -> Any:
        """Return the next message without blocking; raise ``Empty`` if none."""

    @abc.abstractmethod
    def empty(self) -> bool:
        """Tell whether no messages are currently queued."""

    @property
    @abc.abstractmethod
    def full(self) -> bool:
        """Tell whether the mailbox has reached its capacity (a property)."""

    async def close(self) -> None:
        """Release any resources held by the mailbox. No-op by default."""
        return None
class Empty(Exception):
    """Signals that ``get_nowait`` found no message in the mailbox."""
class MemoryMailbox(Mailbox):
    """In-process mailbox backed by ``asyncio.Queue``.

    Args:
        maxsize: Queue capacity passed straight to ``asyncio.Queue``.
    """

    def __init__(self, maxsize: int = 256) -> None:
        self._queue: asyncio.Queue[Any] = asyncio.Queue(maxsize=maxsize)
        self._maxsize = maxsize

    async def put(self, msg: Any) -> bool:
        """Enqueue *msg*, waiting for a free slot if the queue is full.

        ``asyncio.Queue.put`` waits rather than raising ``QueueFull``, so
        this call always returns ``True`` once it completes.
        """
        await self._queue.put(msg)
        return True

    def put_nowait(self, msg: Any) -> bool:
        """Enqueue *msg* without waiting; return ``False`` if the queue is full."""
        try:
            self._queue.put_nowait(msg)
        except asyncio.QueueFull:
            return False
        return True

    async def get(self) -> Any:
        """Wait for and return the next message."""
        return await self._queue.get()

    def get_nowait(self) -> Any:
        """Return the next message immediately or raise ``Empty``."""
        try:
            return self._queue.get_nowait()
        except asyncio.QueueEmpty:
            # Translate to the mailbox-level exception; the asyncio-specific
            # cause is an implementation detail, so suppress the chain.
            raise Empty("mailbox empty") from None

    def empty(self) -> bool:
        """Return ``True`` when no messages are queued."""
        return self._queue.empty()

    @property
    def full(self) -> bool:
        """``True`` when the underlying queue reports it is full."""
        return self._queue.full()
# Type alias for a mailbox factory. The intended meaning (per the trailing
# comment) is "a zero-argument callable returning a Mailbox"; because the
# union includes ``Any``, a type checker accepts any value here.
MailboxFactory = type[Mailbox] | Any # Callable[[], Mailbox]

View File

@ -0,0 +1,150 @@
"""Redis-backed mailbox — persistent, survives process restart.

Requires ``redis[hiredis]`` (``uv add redis[hiredis]``).

Usage::

    import redis.asyncio as redis
    from deerflow.actor import ActorSystem
    from deerflow.actor.mailbox_redis import RedisMailbox

    pool = redis.ConnectionPool.from_url("redis://localhost:6379")
    system = ActorSystem("app")
    ref = await system.spawn(
        MyActor, "worker",
        mailbox=RedisMailbox(pool, "actor:inbox:worker"),
    )
"""
from __future__ import annotations
import json
import logging
from typing import Any
from .mailbox import Empty, Mailbox
from .ref import _Envelope, _Stop
logger = logging.getLogger(__name__)
def _serialize(msg: _Envelope | _Stop) -> str:
    """Encode an envelope or stop sentinel as a JSON string for Redis.

    The envelope's ``payload``, ``correlation_id`` and ``reply_to`` are
    stored; its ``sender`` is not part of the encoded form.

    Raises:
        TypeError: If the payload cannot be encoded as JSON.
    """
    if isinstance(msg, _Stop):
        return json.dumps({"__type__": "stop"})

    try:
        encoded = json.dumps({
            "__type__": "envelope",
            "payload": msg.payload,
            "correlation_id": msg.correlation_id,
            "reply_to": msg.reply_to,
        })
    except (TypeError, ValueError) as e:
        raise TypeError(f"Payload is not JSON-serializable: {e}. RedisMailbox requires JSON-compatible messages.") from e
    return encoded
def _deserialize(data: str | bytes) -> _Envelope | _Stop:
    """Decode a JSON document into an envelope or the stop sentinel.

    ``bytes`` input is decoded as UTF-8 first. Any document whose
    ``__type__`` is not ``"stop"`` is turned into an envelope with
    ``sender`` set to ``None``.
    """
    text = data.decode("utf-8") if isinstance(data, bytes) else data
    record = json.loads(text)
    if record.get("__type__") == "stop":
        return _Stop()
    payload = record.get("payload")
    correlation_id = record.get("correlation_id")
    reply_to = record.get("reply_to")
    return _Envelope(
        payload=payload,
        sender=None,
        correlation_id=correlation_id,
        reply_to=reply_to,
    )
class RedisMailbox(Mailbox):
    """Mailbox backed by a Redis LIST.

    Each actor gets its own Redis key (the ``queue_name``).
    Messages are serialized as JSON, so payloads must be JSON-compatible.

    Args:
        pool: A ``redis.asyncio.ConnectionPool`` instance.
        queue_name: Redis key for this actor's inbox (e.g. ``"actor:inbox:worker"``).
        maxlen: Maximum queue length. 0 = unbounded. When the list is already
            at this length, ``put`` returns False.
        brpop_timeout: Seconds to block on ``get()`` before retrying. Default 1s.

    Raises:
        ImportError: If the ``redis`` package is not installed.
    """

    # Lua script for an atomic bounded push: check the length, then push.
    # KEYS[1] = queue key, ARGV[1] = serialized message, ARGV[2] = max length.
    _LUA_BOUNDED_PUSH = """
    if tonumber(ARGV[2]) > 0 and redis.call('llen', KEYS[1]) >= tonumber(ARGV[2]) then
        return 0
    end
    redis.call('lpush', KEYS[1], ARGV[1])
    return 1
    """

    def __init__(
        self,
        pool: Any,
        queue_name: str,
        *,
        maxlen: int = 0,
        brpop_timeout: float = 1.0,
    ) -> None:
        # Lazy import to avoid a hard dependency on redis. Only the import is
        # guarded so that errors from building the client are not mislabelled.
        try:
            import redis.asyncio as aioredis
        except ImportError as exc:
            raise ImportError("RedisMailbox requires 'redis' package. Install with: uv add redis[hiredis]") from exc

        self._queue_name = queue_name
        self._maxlen = maxlen
        self._brpop_timeout = brpop_timeout
        self._closed = False
        self._redis: Any = aioredis.Redis(connection_pool=pool)

    async def put(self, msg: Any) -> bool:
        """Push *msg* onto the list.

        Returns ``False`` if the mailbox is closed or, for a bounded
        mailbox, if the list is already at ``maxlen``; ``True`` otherwise.
        """
        if self._closed:
            return False
        data = _serialize(msg)
        if self._maxlen > 0:
            # Atomic check+push via Lua script to avoid a TOCTOU race.
            # ``eval(script, numkeys, *keys_and_args)`` is the redis-py call;
            # the client has no ``evalsha_or_eval`` method.
            result = await self._redis.eval(self._LUA_BOUNDED_PUSH, 1, self._queue_name, data, self._maxlen)
            return bool(result)
        await self._redis.lpush(self._queue_name, data)
        return True

    def put_nowait(self, msg: Any) -> bool:
        """Redis cannot do synchronous non-blocking enqueue reliably.

        Returns False so the caller uses dead-letter or task.cancel() fallback.
        Use ``put()`` (async) for reliable delivery.
        """
        return False

    async def get(self) -> Any:
        """Blocking dequeue via BRPOP. Retries until a message arrives.

        Raises:
            Empty: Once the mailbox has been closed.
        """
        while not self._closed:
            result = await self._redis.brpop(self._queue_name, timeout=self._brpop_timeout)
            if result is not None:
                # BRPOP yields a (key, value) pair; only the value is needed.
                _, data = result
                return _deserialize(data)
        raise Empty("mailbox closed")

    def get_nowait(self) -> Any:
        """Always raises ``Empty``: there is no synchronous dequeue."""
        raise Empty("Redis mailbox does not support synchronous get_nowait")

    def empty(self) -> bool:
        # Cannot query Redis synchronously. Return True so drain loops
        # terminate immediately and rely on get_nowait raising Empty.
        return True

    @property
    def full(self) -> bool:
        # Cannot query Redis synchronously. Backpressure is enforced
        # atomically inside put() via the Lua script.
        return False

    async def close(self) -> None:
        """Mark the mailbox closed and close the Redis client."""
        self._closed = True
        await self._redis.aclose()

View File

@ -0,0 +1,79 @@
"""Middleware pipeline — cross-cutting concerns for actors.

Inspired by Proto.Actor's sender/receiver middleware model.
Middleware intercepts messages before/after the actor processes them.

Usage::

    class LoggingMiddleware(Middleware):
        async def on_receive(self, ctx, message, next_fn):
            logger.info("Received: %s", message)
            result = await next_fn(ctx, message)
            logger.info("Replied: %s", result)
            return result

    system = ActorSystem("app")
    ref = await system.spawn(MyActor, "a", middlewares=[LoggingMiddleware()])
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import Any
class ActorMailboxContext:
    """Per-message context handed to every middleware in the chain.

    Attributes are plain slots: the receiving actor's reference, the
    sender (may be ``None``), and the message type string.
    """

    __slots__ = ("actor_ref", "sender", "message_type")

    def __init__(self, actor_ref: Any, sender: Any, message_type: str) -> None:
        # message_type is "tell" or "ask"
        self.actor_ref, self.sender, self.message_type = actor_ref, sender, message_type
# Signature shared by the innermost handler and every link of the
# middleware chain: ``(ctx, message) -> awaitable result``.
NextFn = Callable[[ActorMailboxContext, Any], Awaitable[Any]]
class Middleware:
    """Base class for actor middleware.

    Subclasses override ``on_receive`` to intercept inbound messages and
    must ``await next_fn(ctx, message)`` to let the rest of the chain run.
    The lifecycle hooks are no-ops unless overridden.
    """

    async def on_receive(self, ctx: ActorMailboxContext, message: Any, next_fn: NextFn) -> Any:
        """Intercept one message; the default simply forwards to *next_fn*."""
        outcome = await next_fn(ctx, message)
        return outcome

    async def on_started(self, actor_ref: Any) -> None:
        """Hook invoked when the actor starts."""
        return None

    async def on_stopped(self, actor_ref: Any) -> None:
        """Hook invoked when the actor stops."""
        return None

    async def on_restart(self, actor_ref: Any, error: Exception) -> None:
        """Hook invoked when the actor restarts after a crash.

        Override to reset per-actor-instance state (caches, counters, etc.)
        that should not bleed across restarts.
        """
        return None
def build_middleware_chain(middlewares: list[Middleware], handler: NextFn) -> NextFn:
    """Compose *middlewares* around *handler* and return the outermost callable.

    The first middleware in the list ends up outermost, so
    ``[A, B, C]`` yields ``A(B(C(handler)))``. With an empty list the
    handler itself is returned.
    """

    def _link(mw: Middleware, inner: NextFn) -> NextFn:
        # A factory function gives each link its own ``mw``/``inner`` binding.
        async def _call(ctx: ActorMailboxContext, msg: Any) -> Any:
            return await mw.on_receive(ctx, msg, inner)

        return _call

    chain = handler
    # Wrap from the last middleware inwards so the first one wraps last.
    for mw in reversed(middlewares):
        chain = _link(mw, chain)
    return chain

View File

@ -0,0 +1,216 @@
"""ActorRef — immutable, serializable reference to an actor."""
from __future__ import annotations
import asyncio
import uuid
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from .system import _ActorCell
class ActorRef:
    """Immutable handle for sending messages to an actor.

    Users never construct this directly — it is returned by
    ``ActorSystem.spawn`` or ``ActorContext.spawn``.
    """

    __slots__ = ("_cell",)

    def __init__(self, cell: _ActorCell) -> None:
        self._cell = cell

    @property
    def name(self) -> str:
        """The actor's own name, as held by its cell."""
        return self._cell.name

    @property
    def path(self) -> str:
        """The actor's path, as held by its cell."""
        return self._cell.path

    @property
    def is_alive(self) -> bool:
        """``True`` while the underlying cell is not marked stopped."""
        return not self._cell.stopped

    async def tell(self, message: Any, *, sender: ActorRef | None = None) -> None:
        """Fire-and-forget delivery.

        If the actor is already stopped, the message is routed to the
        system's dead letters instead of being enqueued.
        """
        cell = self._cell
        if not cell.stopped:
            await cell.enqueue(_Envelope(message, sender))
            return
        cell.system._dead_letter(self, message, sender)

    async def ask(self, message: Any, *, timeout: float = 5.0) -> Any:
        """Request-response with timeout.

        A correlation ID is registered with the system's reply registry
        instead of passing a Future through the mailbox, so ``ask`` works
        with any mailbox backend (memory, Redis, RabbitMQ, etc.).

        Raises:
            ActorStoppedError: If the actor is already stopped.
            asyncio.TimeoutError: If no reply arrives within *timeout*.
            Exception: Whatever the pending reply was rejected with, e.g.
                the actor's own exception when ``on_receive`` fails.
        """
        cell = self._cell
        if cell.stopped:
            raise ActorStoppedError(f"Actor {self.path} is stopped")

        corr_id = uuid.uuid4().hex
        pending = cell.system._replies.register(corr_id)
        try:
            request = _Envelope(
                message,
                sender=None,
                correlation_id=corr_id,
                reply_to=cell.system.system_id,
            )
            await cell.enqueue(request)
            return await asyncio.wait_for(pending, timeout=timeout)
        finally:
            # Always drop the registry entry, including on timeout or error.
            cell.system._replies.discard(corr_id)

    def stop(self) -> None:
        """Ask the actor to shut down gracefully."""
        self._cell.request_stop()

    def __repr__(self) -> str:
        state = "alive" if self.is_alive else "dead"
        return f"ActorRef({self.path}, {state})"

    def __eq__(self, other: object) -> bool:
        # Two refs are equal only when they wrap the very same cell object.
        if not isinstance(other, ActorRef):
            return NotImplemented
        return self._cell is other._cell

    def __hash__(self) -> int:
        return id(self._cell)
class ActorStoppedError(Exception):
    """Error raised by ``ask`` when the target actor is already stopped."""
# ---------------------------------------------------------------------------
# Internal message wrappers (serializable — no Future objects)
# ---------------------------------------------------------------------------
class _Envelope:
    """Wrapper around a user message as it travels through a mailbox.

    No ``asyncio.Future`` is stored here; a reply is matched to its
    request through ``correlation_id``, which is what lets ``ask()``
    work across MQ-backed mailboxes.
    """

    __slots__ = ("payload", "sender", "correlation_id", "reply_to")

    def __init__(
        self,
        payload: Any,
        sender: ActorRef | None = None,
        correlation_id: str | None = None,
        reply_to: str | None = None,
    ) -> None:
        self.payload, self.sender = payload, sender
        self.correlation_id = correlation_id
        # System ID of the caller (for cross-process reply routing)
        self.reply_to = reply_to
class _Stop:
    """Marker message: when dequeued, it tells the actor to shut down gracefully."""
# ---------------------------------------------------------------------------
# ReplyRegistry — maps correlation_id → Future (lives on ActorSystem)
# ---------------------------------------------------------------------------
class _ReplyRegistry:
    """In-memory table from correlation ID to the Future awaiting its reply.

    ``ask()`` uses it so that replies can be delivered without ever
    placing a Future inside a mailbox.
    """

    def __init__(self) -> None:
        self._pending: dict[str, asyncio.Future[Any]] = {}

    def register(self, corr_id: str) -> asyncio.Future[Any]:
        """Create a Future on the running loop and file it under *corr_id*."""
        loop = asyncio.get_running_loop()
        waiter: asyncio.Future[Any] = loop.create_future()
        self._pending[corr_id] = waiter
        return waiter

    def _take_open(self, corr_id: str) -> asyncio.Future[Any] | None:
        """Remove the entry for *corr_id*; return it only if still undone."""
        waiter = self._pending.pop(corr_id, None)
        if waiter is None or waiter.done():
            return None
        return waiter

    def resolve(self, corr_id: str, result: Any) -> None:
        """Finish the pending ask for *corr_id* with *result*."""
        waiter = self._take_open(corr_id)
        if waiter is not None:
            waiter.set_result(result)

    def reject(self, corr_id: str, error: Exception) -> None:
        """Finish the pending ask for *corr_id* with *error*."""
        waiter = self._take_open(corr_id)
        if waiter is not None:
            waiter.set_exception(error)

    def discard(self, corr_id: str) -> None:
        """Forget the entry for *corr_id* (e.g. on timeout)."""
        self._pending.pop(corr_id, None)

    def reject_all(self, error: Exception) -> None:
        """Fail every still-open ask with *error* (e.g. on system shutdown)."""
        for waiter in self._pending.values():
            if waiter.done():
                continue
            waiter.set_exception(error)
        self._pending.clear()
# ---------------------------------------------------------------------------
# ReplyChannel — abstraction for routing replies (local or cross-process)
# ---------------------------------------------------------------------------
class _ReplyMessage:
    """Reply payload sent through ReplyChannel.

    Carries the original exception object for local delivery (preserves type).
    For cross-process serialization, use ``to_dict``/``from_dict``.
    """

    __slots__ = ("correlation_id", "result", "error", "exception")

    def __init__(self, correlation_id: str, result: Any = None, error: str | None = None, exception: Exception | None = None) -> None:
        self.correlation_id = correlation_id
        self.result = result
        self.error = error
        self.exception = exception  # Original exception (local only, not serializable)

    def to_dict(self) -> dict[str, Any]:
        """Serialize for cross-process transport (exception becomes string).

        When no explicit ``error`` text was given but an ``exception`` is
        attached, the exception's ``str()`` is used so the failure is not
        silently turned into a successful ``None`` result on the other side.
        """
        error = self.error
        if error is None and self.exception is not None:
            error = str(self.exception)
        return {"correlation_id": self.correlation_id, "result": self.result, "error": error}

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> _ReplyMessage:
        """Rebuild a reply from ``to_dict`` output; ``exception`` stays ``None``."""
        return cls(d["correlation_id"], d.get("result"), d.get("error"))
class ReplyChannel:
    """Routes replies from an actor back to the caller's ReplyRegistry.

    This default implementation resolves locally (same process).
    Override ``send_reply`` for cross-process routing (e.g. via Redis pub/sub).
    """

    async def send_reply(self, reply_to: str, reply: _ReplyMessage, local_registry: _ReplyRegistry) -> None:
        """Deliver *reply* to the system identified by *reply_to*.

        The default assumes *reply_to* is the local system and completes
        the pending entry in *local_registry* directly; *reply_to* itself
        is not inspected. Override for MQ-backed cross-process delivery.
        """
        failure = reply.exception
        if failure is None and reply.error is not None:
            # Cross-process: the exception was serialized to a string.
            failure = RuntimeError(reply.error)

        if failure is None:
            local_registry.resolve(reply.correlation_id, reply.result)
        else:
            # A locally attached exception is passed on with its original type.
            local_registry.reject(reply.correlation_id, failure)

    async def start_listener(self, system_id: str, registry: _ReplyRegistry) -> None:
        """Begin listening for inbound replies (nothing to do locally)."""
        return None

    async def stop_listener(self) -> None:
        """Stop the reply listener (nothing to do locally)."""
        return None

View File

@ -0,0 +1,75 @@
"""Supervision strategies — Erlang/Akka-inspired fault tolerance."""
from __future__ import annotations
import enum
import time
from collections import deque
from collections.abc import Callable
from typing import Any
class Directive(enum.Enum):
    """Action a supervisor takes in response to a child's failure."""

    # Ignore the error and keep processing.
    resume = "resume"
    # Discard state and create a fresh instance.
    restart = "restart"
    # Terminate the child permanently.
    stop = "stop"
    # Propagate the failure to the grandparent.
    escalate = "escalate"
class SupervisorStrategy:
    """Base class for supervision strategies.

    Args:
        max_restarts: Maximum restarts allowed within *within_seconds*;
            ``record_restart`` returns ``False`` once this is exceeded.
        within_seconds: Length of the sliding window used for counting.
        decider: Maps an exception to a ``Directive``. Default: always restart.
    """

    def __init__(
        self,
        *,
        max_restarts: int = 3,
        within_seconds: float = 60.0,
        decider: Callable[[Exception], Directive] | None = None,
    ) -> None:
        if not decider:

            def decider(_error: Exception) -> Directive:
                return Directive.restart

        self.max_restarts = max_restarts
        self.within_seconds = within_seconds
        self.decider = decider
        # Per-child monotonic timestamps of recent restarts.
        self._restart_timestamps: dict[str, deque[float]] = {}

    def decide(self, error: Exception) -> Directive:
        """Ask the decider which directive applies to *error*."""
        return self.decider(error)

    def record_restart(self, child_name: str) -> bool:
        """Log a restart for *child_name*; ``True`` while still within limits."""
        now = time.monotonic()
        history = self._restart_timestamps.setdefault(child_name, deque())
        # Drop timestamps that have slid out of the window.
        oldest_allowed = now - self.within_seconds
        while history and history[0] < oldest_allowed:
            history.popleft()
        history.append(now)
        return len(history) <= self.max_restarts

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        """Return the names of children the directive applies to (abstract)."""
        raise NotImplementedError
class OneForOneStrategy(SupervisorStrategy):
    """Strategy in which a failure touches only the child that failed."""

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        """Return a one-element list holding just *failed_child*."""
        affected = [failed_child]
        return affected
class AllForOneStrategy(SupervisorStrategy):
    """Strategy in which one child's failure touches every sibling too."""

    def apply_to_children(self, failed_child: str, all_children: list[str]) -> list[str]:
        """Return a fresh list containing every name in *all_children*."""
        return [*all_children]

View File

@ -0,0 +1,381 @@
"""ActorSystem — top-level actor container and lifecycle manager."""
from __future__ import annotations
import asyncio
import logging
from collections import deque
from dataclasses import dataclass
from typing import Any
from .actor import Actor, ActorContext
from .mailbox import Empty, Mailbox, MemoryMailbox
from .middleware import ActorMailboxContext, Middleware, NextFn, build_middleware_chain
from .ref import ActorRef, ActorStoppedError, ReplyChannel, _Envelope, _ReplyMessage, _ReplyRegistry, _Stop
from .supervision import Directive, SupervisorStrategy
logger = logging.getLogger(__name__)
# Timeout in seconds passed to ``asyncio.wait_for`` around each middleware
# lifecycle hook (on_started / on_stopped / on_restart).
_MIDDLEWARE_HOOK_TIMEOUT = 10.0
# Default ``maxlen`` of the system's in-memory dead-letter deque.
_MAX_DEAD_LETTERS = 10000
# Number of consecutive failed messages after which a root actor's
# processing loop stops (the counter resets on every successful message).
_MAX_CONSECUTIVE_FAILURES = 10
@dataclass
class DeadLetter:
    """Record of a message that could not be delivered."""

    # Reference of the actor the message was addressed to.
    recipient: ActorRef
    # The undelivered message itself.
    message: Any
    # Who sent it, when known.
    sender: ActorRef | None
class ActorSystem:
    """Top-level actor container.

    Manages root actors and provides the dead letter sink.

    Args:
        name: System name; it is also the first segment of actor paths.
        max_dead_letters: Capacity of the in-memory dead-letter deque.
        executor_workers: Size of the shared thread pool for blocking I/O.
            A falsy value (``None`` or ``0``) disables the pool.
        reply_channel: Reply router; defaults to a local ``ReplyChannel``.
    """

    def __init__(
        self,
        name: str = "system",
        *,
        max_dead_letters: int = _MAX_DEAD_LETTERS,
        executor_workers: int | None = 4,
        reply_channel: ReplyChannel | None = None,
    ) -> None:
        import uuid as _uuid
        from concurrent.futures import ThreadPoolExecutor

        self.name = name
        # A short random suffix distinguishes systems that share a name.
        self.system_id = f"{name}-{_uuid.uuid4().hex[:8]}"
        self._root_cells: dict[str, _ActorCell] = {}
        self._dead_letters: deque[DeadLetter] = deque(maxlen=max_dead_letters)
        self._on_dead_letter: list[Any] = []
        self._shutting_down = False
        self._replies = _ReplyRegistry()
        self._reply_channel = reply_channel or ReplyChannel()
        # Shared thread pool for actors to run blocking I/O
        self._executor = ThreadPoolExecutor(max_workers=executor_workers, thread_name_prefix=f"actor-{name}") if executor_workers else None

    async def spawn(
        self,
        actor_cls: type[Actor],
        name: str,
        *,
        mailbox_size: int = 256,
        mailbox: Mailbox | None = None,
        middlewares: list[Middleware] | None = None,
    ) -> ActorRef:
        """Spawn a root-level actor.

        Args:
            actor_cls: Actor class to instantiate.
            name: Root actor name; must be unique among registered roots.
            mailbox_size: Capacity of the default ``MemoryMailbox``.
            mailbox: Custom mailbox instance. If None, uses MemoryMailbox(mailbox_size).
            middlewares: Optional middleware list for the actor.

        Raises:
            ValueError: If a root actor called *name* is already registered.
        """
        if name in self._root_cells:
            raise ValueError(f"Root actor '{name}' already exists")
        cell = _ActorCell(
            actor_cls=actor_cls,
            name=name,
            parent=None,
            system=self,
            mailbox=mailbox if mailbox is not None else MemoryMailbox(mailbox_size),
            middlewares=middlewares or [],
        )
        self._root_cells[name] = cell
        try:
            await cell.start()
        except BaseException:
            # A cell whose start failed has no processing task; keeping it
            # registered would block the name without a usable actor.
            if self._root_cells.get(name) is cell:
                del self._root_cells[name]
            raise
        return cell.ref

    async def shutdown(self, *, timeout: float = 10.0) -> None:
        """Gracefully stop all actors.

        Root actors are asked to stop and awaited together for at most
        *timeout* seconds; any task still running afterwards is cancelled
        so it does not keep running after shutdown returns.
        """
        self._shutting_down = True
        tasks = []
        for cell in list(self._root_cells.values()):
            cell.request_stop()
            if cell.task is not None:
                tasks.append(cell.task)
        if tasks:
            _, pending = await asyncio.wait(tasks, timeout=timeout)
            for task in pending:
                logger.warning("Actor task %s did not stop within %.1fs — cancelling", task.get_name(), timeout)
                task.cancel()
        self._root_cells.clear()
        self._replies.reject_all(ActorStoppedError("ActorSystem shutting down"))
        await self._reply_channel.stop_listener()
        if self._executor is not None:
            self._executor.shutdown(wait=False)
        logger.info("ActorSystem '%s' shut down (%d dead letters)", self.name, len(self._dead_letters))

    def _dead_letter(self, recipient: ActorRef, message: Any, sender: ActorRef | None) -> None:
        """Record an undeliverable message and notify registered listeners."""
        dl = DeadLetter(recipient=recipient, message=message, sender=sender)
        self._dead_letters.append(dl)
        for cb in self._on_dead_letter:
            try:
                cb(dl)
            except Exception:
                # Listener failures must not break delivery, but they should
                # not vanish silently either.
                logger.exception("Dead letter listener %r failed", cb)
        logger.debug("Dead letter: %s -> %s", type(message).__name__, recipient.path)

    def on_dead_letter(self, callback: Any) -> None:
        """Register a dead letter listener; it is called with each ``DeadLetter``."""
        self._on_dead_letter.append(callback)

    @property
    def dead_letters(self) -> list[DeadLetter]:
        """Copy of the currently retained dead letters, oldest first."""
        return list(self._dead_letters)
# ---------------------------------------------------------------------------
# _ActorCell — internal runtime wrapper
# ---------------------------------------------------------------------------
class _ActorCell:
"""Runtime container for a single actor instance.
Manages the mailbox, processing loop, children, and supervision.
Not part of the public API.
"""
def __init__(
    self,
    actor_cls: type[Actor],
    name: str,
    parent: _ActorCell | None,
    system: ActorSystem,
    mailbox: Mailbox,
    middlewares: list[Middleware] | None = None,
) -> None:
    """Set up the cell's state; no actor is created and no task started here."""
    # Identity and wiring
    self.actor_cls = actor_cls
    self.name = name
    self.parent = parent
    self.system = system
    self.mailbox = mailbox
    self.children: dict[str, _ActorCell] = {}
    self.ref = ActorRef(self)

    # Runtime state, filled in once the cell is started
    self.actor: Actor | None = None
    self.task: asyncio.Task[None] | None = None
    self.stopped = False
    self._supervisor_strategy: SupervisorStrategy | None = None
    self._middlewares = middlewares or []
    self._receive_chain: NextFn | None = None

    # Cache path (immutable after init — parent never changes):
    # "/<system>/<root>/.../<self>", built by walking up the parent links.
    lineage: list[str] = []
    node: _ActorCell | None = self
    while node is not None:
        lineage.append(node.name)
        node = node.parent
    self.path = "/" + "/".join([system.name, *reversed(lineage)])
async def start(self) -> None:
    """Instantiate the actor, run start hooks and launch the processing task."""
    self.actor = self.actor_cls()
    self.actor.context = ActorContext(self)

    async def _inner_handler(_ctx: ActorMailboxContext, message: Any) -> Any:
        # Looks up ``self.actor`` on every call rather than capturing the
        # instance created above.
        return await self.actor.on_receive(message)  # type: ignore[union-attr]

    self._receive_chain = (
        build_middleware_chain(self._middlewares, _inner_handler)
        if self._middlewares
        else _inner_handler
    )

    # Notify middleware of start (with timeout to prevent blocking);
    # only a timeout is tolerated here, other errors propagate.
    for middleware in self._middlewares:
        hook = middleware.on_started(self.ref)
        try:
            await asyncio.wait_for(hook, timeout=_MIDDLEWARE_HOOK_TIMEOUT)
        except asyncio.TimeoutError:
            logger.warning("Middleware %s.on_started timed out for %s", type(middleware).__name__, self.path)

    await self.actor.on_started()
    self.task = asyncio.create_task(self._run(), name=f"actor:{self.path}")
async def enqueue(self, msg: _Envelope | _Stop) -> None:
    """Offer *msg* to the mailbox without blocking.

    If the mailbox refuses it, an ``ask`` envelope has its pending reply
    rejected with ``RuntimeError`` and a ``tell`` envelope goes to dead
    letters. A refused non-envelope message is dropped.
    """
    # NOTE(review): only ``put_nowait`` is used, so a mailbox whose
    # ``put_nowait`` always returns False never accepts a message on this
    # path — confirm that is intended for such backends.
    if self.mailbox.put_nowait(msg):
        return
    if not isinstance(msg, _Envelope):
        return
    if msg.correlation_id is None:
        self.system._dead_letter(self.ref, msg.payload, msg.sender)
    else:
        self.system._replies.reject(msg.correlation_id, RuntimeError(f"Mailbox full: {self.path}"))
def request_stop(self) -> None:
    """Request graceful shutdown. Falls back to task.cancel() if mailbox full.

    Does nothing when the cell is already marked stopped. The cell is
    flagged as stopped as soon as the stop sentinel has been accepted.
    """
    if self.stopped:
        return
    if self.mailbox.put_nowait(_Stop()):
        self.stopped = True
        return
    # The sentinel was refused: cancel the processing task, if one is running.
    task = self.task
    if task is not None and not task.done():
        task.cancel()
async def spawn_child(
    self,
    actor_cls: type[Actor],
    name: str,
    *,
    mailbox_size: int = 256,
    mailbox: Mailbox | None = None,
    middlewares: list[Middleware] | None = None,
) -> ActorRef:
    """Create, register and start a child cell under this cell.

    Raises:
        ValueError: If a child called *name* already exists.
    """
    if name in self.children:
        raise ValueError(f"Child '{name}' already exists under {self.path}")

    # Use the supplied mailbox, or a MemoryMailbox of the requested size.
    inbox = mailbox or MemoryMailbox(mailbox_size)
    child_cell = _ActorCell(
        actor_cls=actor_cls,
        name=name,
        parent=self,
        system=self.system,
        mailbox=inbox,
        middlewares=middlewares or [],
    )
    self.children[name] = child_cell
    await child_cell.start()
    return child_cell.ref
# -- Processing loop -------------------------------------------------------
async def _run(self) -> None:
    """Processing loop: dequeue messages until stopped, then shut down.

    Each envelope is passed through the receive chain. For ``ask``
    envelopes the result, or the raised exception, is sent back through
    the system's reply channel. A failure is reported to the parent
    cell; a root actor instead counts consecutive failures and stops at
    ``_MAX_CONSECUTIVE_FAILURES``. ``_shutdown`` always runs on exit.
    """
    consecutive_failures = 0
    try:
        while not self.stopped:
            try:
                msg = await self.mailbox.get()
            except (asyncio.CancelledError, Empty):
                # Cancellation, or a mailbox that reports it has nothing
                # more to deliver (e.g. it was closed), ends the loop
                # instead of crashing the task.
                break
            if isinstance(msg, _Stop):
                break
            if not isinstance(msg, _Envelope):
                # Anything else is silently skipped.
                continue
            try:
                msg_type = "ask" if msg.correlation_id else "tell"
                ctx = ActorMailboxContext(self.ref, msg.sender, msg_type)
                result = await self._receive_chain(ctx, msg.payload)  # type: ignore[misc]
                if msg.correlation_id is not None:
                    reply = _ReplyMessage(msg.correlation_id, result=result)
                    await self.system._reply_channel.send_reply(msg.reply_to or self.system.system_id, reply, self.system._replies)
                consecutive_failures = 0
            except Exception as exc:
                if msg.correlation_id is not None:
                    # Send both the text and the original exception object.
                    reply = _ReplyMessage(msg.correlation_id, error=str(exc), exception=exc)
                    await self.system._reply_channel.send_reply(msg.reply_to or self.system.system_id, reply, self.system._replies)
                if self.parent is not None:
                    await self.parent._handle_child_failure(self, exc)
                else:
                    consecutive_failures += 1
                    logger.error("Uncaught error in root actor %s (%d/%d): %s", self.path, consecutive_failures, _MAX_CONSECUTIVE_FAILURES, exc)
                    if consecutive_failures >= _MAX_CONSECUTIVE_FAILURES:
                        logger.error("Root actor %s hit consecutive failure limit — stopping", self.path)
                        break
    except asyncio.CancelledError:
        pass  # Fall through to _shutdown
    finally:
        await self._shutdown()
async def _shutdown(self) -> None:
    """Stop children, drain and close the mailbox, run stop hooks, deregister.

    Steps, in order:
      1. mark this cell stopped;
      2. ask all children to stop and wait up to 10 seconds for them
         together, cancelling any that are still running;
      3. drain leftover messages: ``ask`` envelopes are rejected with
         ``ActorStoppedError``, ``tell`` envelopes become dead letters;
      4. close the mailbox;
      5. run middleware and actor ``on_stopped`` hooks (errors are logged);
      6. remove this cell from its parent, or from the system's root
         registry when it is a root actor.
    """
    self.stopped = True

    # Parallel child shutdown prevents cascading timeouts.
    children = list(self.children.values())
    child_tasks = []
    for child in children:
        child.request_stop()
        if child.task is not None:
            child_tasks.append(child.task)
    if child_tasks:
        _, pending = await asyncio.wait(child_tasks, timeout=10.0)
        if pending:
            owner_of = {c.task: c for c in children if c.task is not None}
            for t in pending:
                t.cancel()
                # Mark leaked children as stopped
                leaked = owner_of.get(t)
                if leaked is not None:
                    leaked.stopped = True

    # Drain mailbox → dead letters (use try/except to handle all backends)
    while True:
        try:
            msg = self.mailbox.get_nowait()
        except Empty:
            break
        if isinstance(msg, _Envelope):
            if msg.correlation_id is not None:
                self.system._replies.reject(msg.correlation_id, ActorStoppedError(f"Actor {self.path} stopped"))
            else:
                self.system._dead_letter(self.ref, msg.payload, msg.sender)

    # Release mailbox resources. This runs only on final stop, so the
    # mailbox is not needed again by this cell.
    try:
        await self.mailbox.close()
    except Exception:
        logger.exception("Error closing mailbox for %s", self.path)

    # Lifecycle hook
    for mw in self._middlewares:
        try:
            await asyncio.wait_for(mw.on_stopped(self.ref), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
        except asyncio.TimeoutError:
            logger.warning("Middleware %s.on_stopped timed out for %s", type(mw).__name__, self.path)
        except Exception:
            logger.exception("Error in middleware on_stopped for %s", self.path)
    if self.actor is not None:
        try:
            await self.actor.on_stopped()
        except Exception:
            logger.exception("Error in on_stopped for %s", self.path)

    # Deregister so the name can be reused and the cell can be collected.
    if self.parent is not None:
        self.parent.children.pop(self.name, None)
    elif self.system._root_cells.get(self.name) is self:
        # Identity check: never remove a different cell that took the name.
        del self.system._root_cells[self.name]
# -- Supervision -----------------------------------------------------------
def _get_supervisor_strategy(self) -> SupervisorStrategy:
    """Return the supervisor strategy for this cell, creating it on first use.

    The strategy is obtained from the actor's ``supervisor_strategy()`` once
    and then cached on the cell.
    """
    cached = self._supervisor_strategy
    if cached is None:
        cached = self.actor.supervisor_strategy()  # type: ignore[union-attr]
        self._supervisor_strategy = cached
    return cached
async def _handle_child_failure(self, child: _ActorCell, error: Exception) -> None:
    """Apply this cell's supervisor strategy to a failed child.

    Args:
        child: The child cell whose actor raised.
        error: The exception raised by the child.

    Raises:
        Exception: ``error`` itself, when the strategy decides to escalate.
    """
    strategy = self._get_supervisor_strategy()
    directive = strategy.decide(error)
    # Names of the children the directive applies to, as chosen by the strategy.
    affected = strategy.apply_to_children(child.name, list(self.children.keys()))
    error_name = type(error).__name__

    if directive == Directive.resume:
        logger.info("Supervisor %s: resume %s after %s", self.path, child.path, error_name)
    elif directive == Directive.stop:
        for name in affected:
            target = self.children.get(name)
            if target is not None:
                target.request_stop()
        stopped_paths = [self.children[n].path for n in affected if n in self.children]
        logger.info("Supervisor %s: stop %s after %s", self.path, stopped_paths, error_name)
    elif directive == Directive.escalate:
        logger.info("Supervisor %s: escalate %s", self.path, error_name)
        raise error
    elif directive == Directive.restart:
        for name in affected:
            target = self.children.get(name)
            if target is None:
                continue
            if strategy.record_restart(name):
                await self._restart_child(target, error)
            else:
                # Restart budget exhausted: stop the child instead.
                logger.warning("Supervisor %s: child %s exceeded restart limit — stopping", self.path, target.path)
                target.request_stop()
async def _restart_child(self, child: _ActorCell, error: Exception) -> None:
    """Replace a child's actor instance with a fresh one after a failure.

    The cell itself is reused; only ``child.actor`` is swapped. If the new
    instance fails in ``on_restart``/``on_started`` the child is asked to stop.

    Args:
        child: The cell to restart.
        error: The exception that triggered the restart; passed to the hooks.
    """
    logger.info("Supervisor %s: restarting %s after %s", self.path, child.path, type(error).__name__)

    # Give the outgoing instance its stop hook; errors are logged only.
    previous = child.actor
    if previous is not None:
        try:
            await previous.on_stopped()
        except Exception:
            logger.exception("Error in on_stopped during restart of %s", child.path)

    # Tell each middleware about the restart (time-boxed, best effort).
    for hook in child._middlewares:
        try:
            await asyncio.wait_for(hook.on_restart(child.ref, error), timeout=_MIDDLEWARE_HOOK_TIMEOUT)
        except asyncio.TimeoutError:
            logger.warning("Middleware %s.on_restart timed out for %s", type(hook).__name__, child.path)
        except Exception:
            logger.exception("Error in middleware on_restart for %s", child.path)

    # Build and install the replacement instance, then run its hooks.
    replacement = child.actor_cls()
    replacement.context = ActorContext(child)
    child.actor = replacement
    try:
        await replacement.on_restart(error)
        await replacement.on_started()
    except Exception:
        logger.exception("Error during restart initialization of %s", child.path)
        child.request_stop()

View File

@ -0,0 +1,263 @@
"""Actor framework benchmarks — throughput, latency, concurrency."""
import asyncio
import time
import statistics
from deerflow.actor import Actor, ActorSystem, Middleware
class NoopActor(Actor):
    """Echo actor used as the cheapest possible message handler."""

    async def on_receive(self, message):
        """Return ``message`` unchanged."""
        return message
class CounterActor(Actor):
    """Counts the messages it has processed.

    The literal message ``"get"`` is a read-only query: it returns the
    current count without incrementing it. Counting the query itself would
    inflate every reported total by one, which makes the ``loss`` figure in
    ``bench_concurrent_actors`` come out as ``-num_actors`` instead of 0.
    """

    async def on_started(self):
        """Reset the counter when the actor starts."""
        self.count = 0

    async def on_receive(self, message):
        """Count ``message`` (unless it is ``"get"``) and return the running total."""
        if message != "get":
            self.count += 1
        return self.count
class ChainActor(Actor):
    """Relays each message to the next actor in a chain via ``ask``.

    The last link (``next_ref`` left as ``None``) answers with the message itself.
    """

    # Ref of the downstream actor; assigned externally after spawning.
    next_ref = None

    async def on_receive(self, message):
        """Forward ``message`` downstream and return the reply, or echo it at the end of the chain."""
        downstream = self.next_ref
        if downstream is None:
            return message
        return await downstream.ask(message)
class ComputeActor(Actor):
    """Simulates CPU work by computing Fibonacci numbers through the context's executor."""

    async def on_receive(self, message):
        """Return the ``message``-th Fibonacci number, computed via ``run_in_executor``."""

        def fib(n):
            # Iterative Fibonacci: after n steps ``prev`` holds F(n).
            prev, cur = 0, 1
            for _ in range(n):
                prev, cur = cur, prev + cur
            return prev

        return await self.context.run_in_executor(fib, message)
class CountMiddleware(Middleware):
    """Middleware that counts the messages passing through it."""

    def __init__(self):
        # Number of messages seen so far.
        self.count = 0

    async def on_receive(self, ctx, message, next_fn):
        """Bump the counter, then delegate to the rest of the pipeline."""
        self.count += 1
        reply = await next_fn(ctx, message)
        return reply
def fmt(n):
    """Format a count compactly: ``1.5M``, ``12K`` or the plain number.

    Values of one million and above use one decimal and an ``M`` suffix;
    values of one thousand and above are rounded to whole thousands with a
    ``K`` suffix; anything smaller is returned via ``str``.

    Values that would round up to ``1000K`` (999_500 .. 999_999) are shown
    as ``1.0M`` instead.
    """
    if n >= 1_000_000:
        return f"{n / 1_000_000:.1f}M"
    if n >= 1_000:
        thousands = f"{n / 1_000:.0f}"
        # ``.0f`` rounds, so the top of the K range can spill over to "1000".
        if thousands == "1000":
            return "1.0M"
        return f"{thousands}K"
    return str(n)
async def bench_tell_throughput(n=100_000):
    """Measure tell (fire-and-forget) throughput over ``n`` messages."""
    system = ActorSystem("bench")
    counter = await system.spawn(CounterActor, "counter", mailbox_size=n + 10)

    began = time.perf_counter()
    for _ in range(n):
        await counter.tell("inc")
    # A final ask acts as the barrier: its reply arrives once the queued
    # tells have been processed.
    await counter.ask("get", timeout=30.0)
    elapsed = time.perf_counter() - began

    await system.shutdown()
    per_second = n / elapsed
    print(f" tell throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(per_second))}/s")
async def bench_ask_throughput(n=50_000):
    """Measure ask (request-response) throughput over ``n`` sequential round trips."""
    system = ActorSystem("bench")
    echo = await system.spawn(NoopActor, "echo")

    began = time.perf_counter()
    for _ in range(n):
        await echo.ask("ping")
    elapsed = time.perf_counter() - began

    await system.shutdown()
    per_second = n / elapsed
    print(f" ask throughput: {fmt(n)} msgs in {elapsed:.2f}s = {fmt(int(per_second))}/s")
async def bench_ask_latency(n=10_000):
    """Measure ask round-trip latency and print the p50 / p99 / p99.9 percentiles."""
    system = ActorSystem("bench")
    echo = await system.spawn(NoopActor, "echo")

    # Warm-up round trips are not recorded.
    for _ in range(100):
        await echo.ask("warmup")

    samples_us = []
    for _ in range(n):
        began = time.perf_counter()
        await echo.ask("ping")
        samples_us.append((time.perf_counter() - began) * 1_000_000)  # microseconds
    await system.shutdown()

    samples_us.sort()
    total = len(samples_us)
    median = samples_us[total // 2]
    pct99 = samples_us[int(total * 0.99)]
    pct999 = samples_us[int(total * 0.999)]
    print(f" ask latency: p50={median:.0f}µs p99={pct99:.0f}µs p99.9={pct999:.0f}µs")
async def bench_concurrent_actors(num_actors=1000, msgs_per_actor=100):
    """Measure throughput when many actors are driven concurrently.

    Each actor receives ``msgs_per_actor`` tells followed by one ask; the ask
    replies are summed to report how many messages were accounted for.
    """
    system = ActorSystem("bench")
    counters = [
        await system.spawn(CounterActor, f"a{idx}", mailbox_size=msgs_per_actor + 10)
        for idx in range(num_actors)
    ]

    began = time.perf_counter()

    async def send_batch(ref, n):
        for sent in range(n):
            await ref.tell("inc")
            # Yield control every 50 msgs so actor loops can drain.
            if sent % 50 == 49:
                await asyncio.sleep(0)
        return await ref.ask("get", timeout=30.0)

    counts = await asyncio.gather(*(send_batch(ref, msgs_per_actor) for ref in counters))
    elapsed = time.perf_counter() - began

    total = num_actors * msgs_per_actor
    loss = total - sum(counts)
    rate = total / elapsed
    print(f" {num_actors} actors × {msgs_per_actor} msgs: {fmt(total)} in {elapsed:.2f}s = {fmt(int(rate))}/s (loss: {loss})")
    await system.shutdown()
async def bench_actor_chain(depth=100):
    """Measure ask latency through a chain of ``depth`` actors (per-hop overhead)."""
    system = ActorSystem("bench")
    links = [await system.spawn(ChainActor, f"c{idx}") for idx in range(depth)]

    # Wire each actor to its successor; the last one keeps next_ref = None.
    for upstream, downstream in zip(links, links[1:]):
        upstream._cell.actor.next_ref = downstream

    began = time.perf_counter()
    result = await links[0].ask("ping", timeout=30.0)
    elapsed = time.perf_counter() - began
    assert result == "ping"

    per_hop = elapsed / depth * 1_000_000  # µs
    print(f" chain {depth} hops: {elapsed*1000:.1f}ms total, {per_hop:.0f}µs/hop")
    await system.shutdown()
async def bench_middleware_overhead(n=50_000):
    """Compare ``n`` asks with and without one middleware and print the relative overhead."""
    counting_mw = CountMiddleware()
    system_plain = ActorSystem("plain")
    ref_plain = await system_plain.spawn(NoopActor, "echo")
    system_mw = ActorSystem("mw")
    ref_mw = await system_mw.spawn(NoopActor, "echo", middlewares=[counting_mw])

    async def time_asks(ref):
        # Wall-clock time for n sequential round trips against ``ref``.
        began = time.perf_counter()
        for _ in range(n):
            await ref.ask("p")
        return time.perf_counter() - began

    # Plain run first, then the run with middleware.
    plain_elapsed = await time_asks(ref_plain)
    mw_elapsed = await time_asks(ref_mw)

    overhead = ((mw_elapsed - plain_elapsed) / plain_elapsed) * 100
    print(f" middleware overhead: {overhead:+.1f}% ({fmt(n)} ask calls, 1 middleware)")
    await system_plain.shutdown()
    await system_mw.shutdown()
async def bench_executor_parallel(num_tasks=16):
    """Time ``num_tasks`` concurrent fib(10K) requests on a system with 8 executor workers."""
    system = ActorSystem("bench", executor_workers=8)
    workers = []
    for idx in range(num_tasks):
        workers.append(await system.spawn(ComputeActor, f"cpu{idx}"))

    began = time.perf_counter()
    await asyncio.gather(*(worker.ask(10_000, timeout=30.0) for worker in workers))
    elapsed = time.perf_counter() - began

    print(f" executor parallel: {num_tasks} fib(10K) in {elapsed*1000:.0f}ms ({num_tasks/elapsed:.0f} tasks/s)")
    await system.shutdown()
async def bench_spawn_teardown(n=5000):
    """Time spawning ``n`` actors and, separately, shutting the system down."""
    system = ActorSystem("bench")

    spawn_began = time.perf_counter()
    spawned = []
    for idx in range(n):
        spawned.append(await system.spawn(NoopActor, f"a{idx}"))
    spawn_elapsed = time.perf_counter() - spawn_began

    shutdown_began = time.perf_counter()
    await system.shutdown()
    shutdown_elapsed = time.perf_counter() - shutdown_began

    print(f" spawn {n}: {spawn_elapsed*1000:.0f}ms ({n/spawn_elapsed:.0f}/s)")
    print(f" shutdown {n}: {shutdown_elapsed*1000:.0f}ms")
async def main():
    """Run every benchmark, grouped under section headings, and print a banner around them."""
    rule = "=" * 60
    # (heading, benchmarks) in the order they are reported.
    sections = (
        ("[Throughput]", (bench_tell_throughput, bench_ask_throughput)),
        ("[Latency]", (bench_ask_latency, bench_actor_chain)),
        ("[Concurrency]", (bench_concurrent_actors, bench_executor_parallel)),
        ("[Overhead]", (bench_middleware_overhead,)),
        ("[Lifecycle]", (bench_spawn_teardown,)),
    )

    print(rule)
    print(" Actor Framework Benchmarks")
    print(rule)
    print()
    for heading, benchmarks in sections:
        print(heading)
        for bench in benchmarks:
            await bench()
        print()
    print(rule)
    print(" Done")
    print(rule)


if __name__ == "__main__":
    asyncio.run(main())

442
backend/tests/test_actor.py Normal file
View File

@ -0,0 +1,442 @@
"""Tests for the async Actor framework."""
import asyncio
import pytest
from deerflow.actor import (
Actor,
ActorRef,
ActorSystem,
AllForOneStrategy,
Directive,
Middleware,
OneForOneStrategy,
)
from deerflow.actor.ref import ActorStoppedError
# ---------------------------------------------------------------------------
# Basic actors for testing
# ---------------------------------------------------------------------------
class EchoActor(Actor):
    """Replies with exactly the message it received."""

    async def on_receive(self, message):
        """Return ``message`` unchanged."""
        return message
class CounterActor(Actor):
    """Counter driven by two messages: ``"inc"`` increments, ``"get"`` reads."""

    async def on_started(self):
        """Start counting from zero."""
        self.count = 0

    async def on_receive(self, message):
        """Handle ``"get"`` / ``"inc"``; any other message is ignored and yields ``None``."""
        if message == "get":
            return self.count
        if message == "inc":
            self.count += 1
        return None
class CrashActor(Actor):
    """Raises ``ValueError("boom")`` on ``"crash"`` and answers ``"ok"`` otherwise."""

    async def on_receive(self, message):
        """Fail on the ``"crash"`` message; reply ``"ok"`` to everything else."""
        if message != "crash":
            return "ok"
        raise ValueError("boom")
class ParentActor(Actor):
    """Parent that spawns one ``CrashActor`` child and supervises it one-for-one.

    The strategy allows up to 3 restarts within 60 seconds.
    """

    def __init__(self):
        # Ref of the spawned child; populated in on_started.
        self.child_ref: ActorRef | None = None
        self.restarts = 0

    def supervisor_strategy(self):
        """Return the one-for-one strategy used for the child."""
        return OneForOneStrategy(max_restarts=3, within_seconds=60)

    async def on_started(self):
        """Spawn the child named ``"child"``."""
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        """Reply to ``"get_child"`` with the child's ref; other messages yield ``None``."""
        if message != "get_child":
            return None
        return self.child_ref
class StopOnCrashParent(Actor):
    """Parent whose strategy answers every child failure with ``Directive.stop``."""

    def supervisor_strategy(self):
        """Return a one-for-one strategy whose decider always chooses ``stop``."""

        def always_stop(_error):
            return Directive.stop

        return OneForOneStrategy(decider=always_stop)

    async def on_started(self):
        """Spawn the ``CrashActor`` child named ``"child"``."""
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        """Reply to ``"get_child"`` with the child's ref; other messages yield ``None``."""
        if message != "get_child":
            return None
        return self.child_ref
class AllForOneParent(Actor):
    """Parent with two children (a counter and a crasher) under an all-for-one strategy."""

    def supervisor_strategy(self):
        """Return the all-for-one strategy: at most 2 restarts within 60 seconds."""
        return AllForOneStrategy(max_restarts=2, within_seconds=60)

    async def on_started(self):
        """Spawn ``c1`` (CounterActor) and ``c2`` (CrashActor), in that order."""
        self.c1 = await self.context.spawn(CounterActor, "c1")
        self.c2 = await self.context.spawn(CrashActor, "c2")

    async def on_receive(self, message):
        """Reply to ``"get_children"`` with ``(c1, c2)``; other messages yield ``None``."""
        if message != "get_children":
            return None
        return (self.c1, self.c2)
class LifecycleActor(Actor):
    """Records which lifecycle hooks ran, using class-level flags tests can inspect."""

    # Class-level (shared) markers; tests reset them before use.
    started = False
    stopped = False
    restarted_with: Exception | None = None

    async def on_started(self):
        """Record that the start hook ran."""
        LifecycleActor.started = True

    async def on_stopped(self):
        """Record that the stop hook ran."""
        LifecycleActor.stopped = True

    async def on_restart(self, error):
        """Record the error that caused the restart."""
        LifecycleActor.restarted_with = error

    async def on_receive(self, message):
        """Raise ``RuntimeError`` on ``"crash"``; reply ``"alive"`` otherwise."""
        if message != "crash":
            return "alive"
        raise RuntimeError("lifecycle crash")
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestBasicMessaging:
    """tell/ask round trips, timeouts, and messaging a stopped actor."""

    @pytest.mark.anyio
    async def test_tell_and_ask(self):
        """A tell followed by an ask on the same ref: the ask still gets its own reply."""
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo")
        # Exercise tell as the test name promises; it is fire-and-forget and
        # must not interfere with the reply to the ask that follows.
        await ref.tell("ignored")
        result = await ref.ask("hello")
        assert result == "hello"
        await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_timeout(self):
        """ask raises ``asyncio.TimeoutError`` when the actor is slower than the timeout."""

        class SlowActor(Actor):
            async def on_receive(self, message):
                await asyncio.sleep(10)

        system = ActorSystem("test")
        ref = await system.spawn(SlowActor, "slow")
        with pytest.raises(asyncio.TimeoutError):
            await ref.ask("hi", timeout=0.1)
        await system.shutdown()

    @pytest.mark.anyio
    async def test_tell_fire_and_forget(self):
        """Three tells are all processed; a later ask observes their effect."""
        system = ActorSystem("test")
        ref = await system.spawn(CounterActor, "counter")
        await ref.tell("inc")
        await ref.tell("inc")
        await ref.tell("inc")
        # Give the actor time to process
        await asyncio.sleep(0.05)
        count = await ref.ask("get")
        assert count == 3
        await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_stopped_actor(self):
        """ask on a stopped actor raises ``ActorStoppedError``."""
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo")
        ref.stop()
        await asyncio.sleep(0.05)
        with pytest.raises(ActorStoppedError):
            await ref.ask("hello")
        await system.shutdown()

    @pytest.mark.anyio
    async def test_tell_stopped_actor_goes_to_dead_letters(self):
        """tell on a stopped actor is recorded in the system's dead letters."""
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo")
        ref.stop()
        await asyncio.sleep(0.05)
        await ref.tell("orphan")
        assert len(system.dead_letters) >= 1
        await system.shutdown()
class TestActorPath:
    """Actor paths are built from the system name and the spawn hierarchy."""

    @pytest.mark.anyio
    async def test_root_actor_path(self):
        """A root actor's path is ``/<system>/<name>``."""
        app = ActorSystem("app")
        echo = await app.spawn(EchoActor, "echo")
        assert echo.path == "/app/echo"
        await app.shutdown()

    @pytest.mark.anyio
    async def test_child_actor_path(self):
        """A child's path is nested under its parent's path."""
        app = ActorSystem("app")
        parent_ref = await app.spawn(ParentActor, "parent")
        child_ref: ActorRef = await parent_ref.ask("get_child")
        assert child_ref.path == "/app/parent/child"
        await app.shutdown()
class TestLifecycle:
    """Start/stop hooks and system-wide shutdown."""

    @pytest.mark.anyio
    async def test_on_started_called(self):
        """``on_started`` has already run by the time ``spawn`` returns."""
        LifecycleActor.started = False
        actors = ActorSystem("test")
        await actors.spawn(LifecycleActor, "lc")
        assert LifecycleActor.started is True
        await actors.shutdown()

    @pytest.mark.anyio
    async def test_on_stopped_called(self):
        """``on_stopped`` runs shortly after the ref is stopped."""
        LifecycleActor.stopped = False
        actors = ActorSystem("test")
        lifecycle_ref = await actors.spawn(LifecycleActor, "lc")
        lifecycle_ref.stop()
        await asyncio.sleep(0.1)
        assert LifecycleActor.stopped is True
        await actors.shutdown()

    @pytest.mark.anyio
    async def test_shutdown_stops_all(self):
        """After ``shutdown`` no spawned actor reports itself alive."""
        actors = ActorSystem("test")
        first = await actors.spawn(EchoActor, "a")
        second = await actors.spawn(EchoActor, "b")
        await actors.shutdown()
        assert not first.is_alive
        assert not second.is_alive
class TestSupervision:
    """Supervisor directives: restart, stop, restart limits, all-for-one."""

    @pytest.mark.anyio
    async def test_restart_on_crash(self):
        """A crashed child is restarted and keeps serving under the same ref."""
        actors = ActorSystem("test")
        supervisor = await actors.spawn(ParentActor, "parent")
        crasher: ActorRef = await supervisor.ask("get_child")

        # The failure is reported back to the asker.
        with pytest.raises(ValueError, match="boom"):
            await crasher.ask("crash")
        await asyncio.sleep(0.1)

        # The ref stays valid across the restart.
        assert crasher.is_alive
        reply = await crasher.ask("safe")
        assert reply == "ok"
        await actors.shutdown()

    @pytest.mark.anyio
    async def test_stop_directive(self):
        """With a stop directive the crashed child is no longer alive."""
        actors = ActorSystem("test")
        supervisor = await actors.spawn(StopOnCrashParent, "parent")
        crasher: ActorRef = await supervisor.ask("get_child")
        with pytest.raises(ValueError, match="boom"):
            await crasher.ask("crash")
        await asyncio.sleep(0.1)
        assert not crasher.is_alive
        await actors.shutdown()

    @pytest.mark.anyio
    async def test_restart_limit_exceeded(self):
        """Crashing more often than ``max_restarts`` allows stops the child."""
        actors = ActorSystem("test")

        class StrictParent(Actor):
            def supervisor_strategy(self):
                return OneForOneStrategy(max_restarts=2, within_seconds=60)

            async def on_started(self):
                self.child_ref = await self.context.spawn(CrashActor, "child")

            async def on_receive(self, message):
                return self.child_ref

        supervisor = await actors.spawn(StrictParent, "parent")
        crasher: ActorRef = await supervisor.ask("any")

        # Three crashes against a budget of two restarts.
        for _attempt in range(3):
            try:
                await crasher.ask("crash")
            except (ValueError, ActorStoppedError):
                pass
            await asyncio.sleep(0.05)

        # Past the limit the child must have been stopped.
        assert not crasher.is_alive
        await actors.shutdown()

    @pytest.mark.anyio
    async def test_all_for_one_restarts_siblings(self):
        """Under all-for-one, one child's crash also resets its sibling's state."""
        actors = ActorSystem("test")
        supervisor = await actors.spawn(AllForOneParent, "parent")
        counter, crasher = await supervisor.ask("get_children")

        # Give the counter some state to lose.
        await counter.tell("inc")
        await asyncio.sleep(0.05)
        before = await counter.ask("get")
        assert before == 1

        # Crash the sibling; the strategy should restart both children.
        try:
            await crasher.ask("crash")
        except ValueError:
            pass
        await asyncio.sleep(0.1)

        # A restarted counter starts again from zero.
        after = await counter.ask("get")
        assert after == 0
        await actors.shutdown()
class TestDeadLetters:
    """Dead-letter callbacks registered on the system."""

    @pytest.mark.anyio
    async def test_dead_letter_callback(self):
        """A tell to a stopped actor is handed to the registered callback."""
        captured = []
        actors = ActorSystem("test")
        actors.on_dead_letter(lambda letter: captured.append(letter))
        echo = await actors.spawn(EchoActor, "echo")
        echo.stop()
        await asyncio.sleep(0.05)
        await echo.tell("orphan")
        assert len(captured) >= 1
        latest = captured[-1]
        assert latest.message == "orphan"
        await actors.shutdown()
class TestDuplicateNames:
    """Name uniqueness among root actors."""

    @pytest.mark.anyio
    async def test_duplicate_root_name_raises(self):
        """Spawning a second root actor under an existing name raises ``ValueError``."""
        actors = ActorSystem("test")
        await actors.spawn(EchoActor, "echo")
        with pytest.raises(ValueError, match="already exists"):
            await actors.spawn(EchoActor, "echo")
        await actors.shutdown()
# ---------------------------------------------------------------------------
# Middleware tests
# ---------------------------------------------------------------------------
class LogMiddleware(Middleware):
    """Middleware that records messages, replies and lifecycle events in ``self.log``."""

    def __init__(self):
        # Chronological record: "started", "before:<msg>", "after:<reply>", "stopped".
        self.log: list[str] = []

    async def on_receive(self, ctx, message, next_fn):
        """Log the message, run the rest of the pipeline, log and return its reply."""
        self.log.append(f"before:{message}")
        reply = await next_fn(ctx, message)
        self.log.append(f"after:{reply}")
        return reply

    async def on_started(self, actor_ref):
        """Record the start hook."""
        self.log.append("started")

    async def on_stopped(self, actor_ref):
        """Record the stop hook."""
        self.log.append("stopped")
class TransformMiddleware(Middleware):
    """Upper-cases string messages before handing them on; other types pass through."""

    async def on_receive(self, ctx, message, next_fn):
        """Forward ``message`` (upper-cased when it is a ``str``) to the rest of the pipeline."""
        outgoing = message.upper() if isinstance(message, str) else message
        return await next_fn(ctx, outgoing)
class TestExecutor:
    """``context.run_in_executor`` off-loads blocking calls to the system's executor."""

    @pytest.mark.anyio
    async def test_run_in_executor(self):
        """Blocking function runs in thread pool without blocking event loop."""
        import time

        class BlockingActor(Actor):
            async def on_receive(self, message):
                # Stand-in for blocking I/O, executed via the executor.
                await self.context.run_in_executor(time.sleep, 0.01)
                return "done"

        actors = ActorSystem("test", executor_workers=2)
        blocker = await actors.spawn(BlockingActor, "blocker")
        reply = await blocker.ask("go", timeout=5.0)
        assert reply == "done"
        await actors.shutdown()

    @pytest.mark.anyio
    async def test_concurrent_blocking_calls(self):
        """Multiple actors can run blocking I/O concurrently via shared pool."""
        import time

        class SlowActor(Actor):
            async def on_receive(self, message):
                await self.context.run_in_executor(time.sleep, 0.1)
                return "ok"

        actors = ActorSystem("test", executor_workers=4)
        sleepers = []
        for idx in range(4):
            sleepers.append(await actors.spawn(SlowActor, f"s{idx}"))

        began = time.monotonic()
        replies = await asyncio.gather(*(ref.ask("go", timeout=5.0) for ref in sleepers))
        elapsed = time.monotonic() - began

        assert all(reply == "ok" for reply in replies)
        # 4 parallel × 0.1s should finish in ~0.1-0.2s, not 0.4s
        assert elapsed < 0.3
        await actors.shutdown()
class TestMiddleware:
    """Middleware interception, lifecycle hooks, chaining order, and tell coverage."""

    @pytest.mark.anyio
    async def test_middleware_intercepts_messages(self):
        """The middleware sees both the incoming message and the actor's reply."""
        mw = LogMiddleware()
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
        result = await ref.ask("hello")
        assert result == "hello"
        assert "before:hello" in mw.log
        assert "after:hello" in mw.log
        await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_lifecycle_hooks(self):
        """``on_started`` fires during spawn and ``on_stopped`` after the actor stops."""
        mw = LogMiddleware()
        system = ActorSystem("test")
        ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
        assert "started" in mw.log
        ref.stop()
        await asyncio.sleep(0.1)
        assert "stopped" in mw.log
        await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_chain_order(self):
        """First middleware wraps outermost — sees original message."""
        mw1 = LogMiddleware()
        mw2 = TransformMiddleware()
        system = ActorSystem("test")
        # Chain: mw1(mw2(actor)). mw1 logs original, mw2 uppercases, actor echoes
        ref = await system.spawn(EchoActor, "echo", middlewares=[mw1, mw2])
        result = await ref.ask("hello")
        assert result == "HELLO"  # TransformMiddleware uppercased
        assert "before:hello" in mw1.log  # LogMiddleware saw original
        assert "after:HELLO" in mw1.log  # LogMiddleware saw transformed result
        await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_with_tell(self):
        """Fire-and-forget messages pass through the middleware chain too."""
        mw = LogMiddleware()
        system = ActorSystem("test")
        ref = await system.spawn(CounterActor, "counter", middlewares=[mw])
        # Nothing has been delivered yet, so no message entry may be logged.
        assert any("before:" in entry for entry in mw.log) is False
        await ref.tell("inc")
        # The ask is a barrier: its reply shows the earlier tell was processed.
        # NOTE(review): relies on in-order mailbox delivery — confirm for
        # non-default mailbox backends.
        assert await ref.ask("get") == 1
        # tell goes through middleware too
        assert "before:inc" in mw.log
        await system.shutdown()