mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-26 03:38:06 +00:00
- Fix spawn() zombie cell: clean up registry on start() failure
- Fix shutdown(): cancel + await tasks that exceed graceful timeout
- Fix _shutdown(): await mailbox.close() to release backend resources
- Fix escalate directive: stop failing child before propagating to grandparent
- Fix RedisMailbox.put(): wrap Redis errors in try/except, return False on failure
- Fix retry.py: replace assert with proper raise for last_exc
- Add put_batch() to Mailbox abstraction for single-roundtrip bulk enqueue
- Add RedisMailbox.put_batch() with atomic Lua script for bounded queues
- Add MailboxFullError exception type for semantic backpressure handling
- Add redis>=7.4.0 dependency with public PyPI sources in uv.lock

Tests added (31 total, up from 27):
- test_middleware_on_restart_hook: verifies middleware.on_restart() on supervision restart
- test_ask_propagates_actor_exception: ask() re-raises original exception type
- test_ask_propagates_exception_while_supervised: exception propagates; root actor survives
- test_ask_timeout_late_reply_no_exception: late reply after timeout is silent no-op
- test_actor_backpressure.py: MailboxFullError + dead letter on full mailbox
- test_actor_retry.py: ask_with_retry with exponential backoff
- test_mailbox_redis.py: RedisMailbox put/get/batch/close
- bench_actor_redis.py: RedisMailbox throughput benchmarks
535 lines
17 KiB
Python
535 lines
17 KiB
Python
"""Tests for the async Actor framework."""
|
||
|
||
import asyncio
|
||
|
||
import pytest
|
||
|
||
from deerflow.actor import (
|
||
Actor,
|
||
ActorRef,
|
||
ActorSystem,
|
||
AllForOneStrategy,
|
||
Directive,
|
||
Middleware,
|
||
OneForOneStrategy,
|
||
)
|
||
from deerflow.actor.ref import ActorStoppedError
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Basic actors for testing
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class EchoActor(Actor):
    """Test actor whose reply to any message is that same message."""

    async def on_receive(self, message):
        echoed = message
        return echoed
|
||
|
||
|
||
class CounterActor(Actor):
    """Test actor holding a single integer counter.

    Messages:
        ``"inc"`` -- add one to the counter; replies ``None``.
        ``"get"`` -- reply with the current counter value.
        anything else -- ignored; replies ``None``.
    """

    async def on_started(self):
        # The counter is (re)set in on_started rather than __init__.
        self.count = 0

    async def on_receive(self, message):
        if message == "inc":
            self.count += 1
            return None
        if message == "get":
            return self.count
        return None
|
||
|
||
|
||
class CrashActor(Actor):
    """Test actor that raises ``ValueError("boom")`` on ``"crash"``.

    Every other message gets the reply ``"ok"``.
    """

    async def on_receive(self, message):
        if message != "crash":
            return "ok"
        raise ValueError("boom")
|
||
|
||
|
||
class ParentActor(Actor):
    """Supervising parent that owns a single ``CrashActor`` named ``"child"``.

    Uses a one-for-one strategy (``max_restarts=3`` within 60 seconds) and
    replies to ``"get_child"`` with the child's ``ActorRef``; any other
    message yields ``None``.
    """

    def __init__(self):
        # Filled in by on_started, which spawns the child via self.context.
        self.child_ref: ActorRef | None = None
        # NOTE(review): not read or updated anywhere visible in this file.
        self.restarts = 0

    def supervisor_strategy(self):
        return OneForOneStrategy(within_seconds=60, max_restarts=3)

    async def on_started(self):
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        return self.child_ref if message == "get_child" else None
|
||
|
||
|
||
class StopOnCrashParent(Actor):
    """Parent whose supervisor decides ``Directive.stop`` for every child failure.

    Spawns one ``CrashActor`` named ``"child"`` and replies to
    ``"get_child"`` with its ``ActorRef``.
    """

    def supervisor_strategy(self):
        def always_stop(_):
            # The failure itself is irrelevant: every error stops the child.
            return Directive.stop

        return OneForOneStrategy(decider=always_stop)

    async def on_started(self):
        self.child_ref = await self.context.spawn(CrashActor, "child")

    async def on_receive(self, message):
        if message != "get_child":
            return None
        return self.child_ref
|
||
|
||
|
||
class AllForOneParent(Actor):
    """Parent supervising two children under an all-for-one strategy.

    ``c1`` is a stateful ``CounterActor`` and ``c2`` a ``CrashActor``; the
    supervision test crashes ``c2`` and checks that ``c1`` was restarted too.
    Replies to ``"get_children"`` with ``(c1, c2)``.
    """

    def supervisor_strategy(self):
        return AllForOneStrategy(within_seconds=60, max_restarts=2)

    async def on_started(self):
        self.c1 = await self.context.spawn(CounterActor, "c1")
        self.c2 = await self.context.spawn(CrashActor, "c2")

    async def on_receive(self, message):
        if message == "get_children":
            children = (self.c1, self.c2)
            return children
        return None
|
||
|
||
|
||
class LifecycleActor(Actor):
    """Test actor that records lifecycle hook invocations in class attributes.

    The flags are written on the class, not the instance, so tests can read
    them without holding the actor instance. Tests reset the flag they check
    before spawning.
    """

    started = False
    stopped = False
    restarted_with: Exception | None = None

    async def on_started(self):
        LifecycleActor.started = True

    async def on_stopped(self):
        LifecycleActor.stopped = True

    async def on_restart(self, error):
        # Remember the error that triggered the restart.
        LifecycleActor.restarted_with = error

    async def on_receive(self, message):
        if message != "crash":
            return "alive"
        raise RuntimeError("lifecycle crash")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestBasicMessaging:
    """tell/ask round trips, ask timeouts, and messaging a stopped actor.

    Each test shuts its ``ActorSystem`` down in a ``finally`` block so a
    failing assertion does not leave actors running into later tests.
    """

    @pytest.mark.anyio
    async def test_tell_and_ask(self):
        system = ActorSystem("test")
        try:
            ref = await system.spawn(EchoActor, "echo")
            result = await ref.ask("hello")
            assert result == "hello"
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_timeout(self):
        class SlowActor(Actor):
            async def on_receive(self, message):
                await asyncio.sleep(10)

        system = ActorSystem("test")
        try:
            ref = await system.spawn(SlowActor, "slow")
            with pytest.raises(asyncio.TimeoutError):
                await ref.ask("hi", timeout=0.1)
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_tell_fire_and_forget(self):
        system = ActorSystem("test")
        try:
            ref = await system.spawn(CounterActor, "counter")
            await ref.tell("inc")
            await ref.tell("inc")
            await ref.tell("inc")
            # Give the actor time to process
            await asyncio.sleep(0.05)
            count = await ref.ask("get")
            assert count == 3
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_stopped_actor(self):
        system = ActorSystem("test")
        try:
            ref = await system.spawn(EchoActor, "echo")
            ref.stop()
            # stop() is not awaited; give the actor time to wind down.
            await asyncio.sleep(0.05)
            with pytest.raises(ActorStoppedError):
                await ref.ask("hello")
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_tell_stopped_actor_goes_to_dead_letters(self):
        system = ActorSystem("test")
        try:
            ref = await system.spawn(EchoActor, "echo")
            ref.stop()
            await asyncio.sleep(0.05)
            await ref.tell("orphan")
            assert len(system.dead_letters) >= 1
        finally:
            await system.shutdown()
|
||
|
||
|
||
class TestActorPath:
    """Actor paths are ``/<system>/<name>``, nesting children under their parent."""

    @pytest.mark.anyio
    async def test_root_actor_path(self):
        system = ActorSystem("app")
        try:
            ref = await system.spawn(EchoActor, "echo")
            assert ref.path == "/app/echo"
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_child_actor_path(self):
        system = ActorSystem("app")
        try:
            parent = await system.spawn(ParentActor, "parent")
            child: ActorRef = await parent.ask("get_child")
            assert child.path == "/app/parent/child"
        finally:
            await system.shutdown()
|
||
|
||
|
||
class TestLifecycle:
    """on_started / on_stopped hooks and system-wide shutdown.

    Systems are shut down in ``finally`` so a failed assertion does not
    leave actors running into later tests.
    """

    @pytest.mark.anyio
    async def test_on_started_called(self):
        LifecycleActor.started = False
        system = ActorSystem("test")
        try:
            await system.spawn(LifecycleActor, "lc")
            assert LifecycleActor.started is True
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_on_stopped_called(self):
        LifecycleActor.stopped = False
        system = ActorSystem("test")
        try:
            ref = await system.spawn(LifecycleActor, "lc")
            ref.stop()
            # stop() is not awaited; give the actor time to run on_stopped.
            await asyncio.sleep(0.1)
            assert LifecycleActor.stopped is True
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_shutdown_stops_all(self):
        system = ActorSystem("test")
        try:
            r1 = await system.spawn(EchoActor, "a")
            r2 = await system.spawn(EchoActor, "b")
        finally:
            # Shutdown is the behaviour under test, so it must run before
            # the assertions below.
            await system.shutdown()
        assert not r1.is_alive
        assert not r2.is_alive
|
||
|
||
|
||
class TestSupervision:
    """Restart, stop, restart-limit, and all-for-one supervision behaviour.

    Systems are shut down in ``finally`` so a failed assertion does not
    leave supervised actors running into later tests.
    """

    @pytest.mark.anyio
    async def test_restart_on_crash(self):
        system = ActorSystem("test")
        try:
            parent = await system.spawn(ParentActor, "parent")
            child: ActorRef = await parent.ask("get_child")

            # Crash the child
            with pytest.raises(ValueError, match="boom"):
                await child.ask("crash")
            await asyncio.sleep(0.1)

            # Child should still be alive (restarted)
            assert child.is_alive
            result = await child.ask("safe")
            assert result == "ok"
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_stop_directive(self):
        system = ActorSystem("test")
        try:
            parent = await system.spawn(StopOnCrashParent, "parent")
            child: ActorRef = await parent.ask("get_child")

            with pytest.raises(ValueError, match="boom"):
                await child.ask("crash")
            await asyncio.sleep(0.1)

            assert not child.is_alive
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_restart_limit_exceeded(self):
        class StrictParent(Actor):
            def supervisor_strategy(self):
                return OneForOneStrategy(max_restarts=2, within_seconds=60)

            async def on_started(self):
                self.child_ref = await self.context.spawn(CrashActor, "child")

            async def on_receive(self, message):
                return self.child_ref

        system = ActorSystem("test")
        try:
            parent = await system.spawn(StrictParent, "parent")
            child: ActorRef = await parent.ask("any")

            # Exhaust restart limit (2 restarts allowed, so the 3rd crash stops it).
            for _ in range(3):
                try:
                    await child.ask("crash")
                except (ValueError, ActorStoppedError):
                    pass
                await asyncio.sleep(0.05)

            # After exceeding limit, child should be stopped
            assert not child.is_alive
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_all_for_one_restarts_siblings(self):
        system = ActorSystem("test")
        try:
            parent = await system.spawn(AllForOneParent, "parent")
            c1, c2 = await parent.ask("get_children")

            # Increment counter on c1
            await c1.tell("inc")
            await asyncio.sleep(0.05)
            count_before = await c1.ask("get")
            assert count_before == 1

            # Crash c2 → AllForOne should restart both
            try:
                await c2.ask("crash")
            except ValueError:
                # The crash itself is not under test here; only its effect on c1.
                pass
            await asyncio.sleep(0.1)

            # c1 was restarted, counter should be 0
            count_after = await c1.ask("get")
            assert count_after == 0
        finally:
            await system.shutdown()
|
||
|
||
|
||
class TestDeadLetters:
    """Messages sent to a stopped actor reach registered dead-letter callbacks."""

    @pytest.mark.anyio
    async def test_dead_letter_callback(self):
        received = []
        system = ActorSystem("test")
        # The bound method is a valid callback; no lambda wrapper needed.
        system.on_dead_letter(received.append)
        try:
            ref = await system.spawn(EchoActor, "echo")
            ref.stop()
            await asyncio.sleep(0.05)
            await ref.tell("orphan")

            assert len(received) >= 1
            assert received[-1].message == "orphan"
        finally:
            await system.shutdown()
|
||
|
||
|
||
class TestDuplicateNames:
    """Spawning two root actors with the same name is rejected."""

    @pytest.mark.anyio
    async def test_duplicate_root_name_raises(self):
        system = ActorSystem("test")
        try:
            await system.spawn(EchoActor, "echo")
            with pytest.raises(ValueError, match="already exists"):
                await system.spawn(EchoActor, "echo")
        finally:
            # Runs even if the duplicate spawn unexpectedly succeeds.
            await system.shutdown()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Middleware tests
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class LogMiddleware(Middleware):
    """Middleware that appends a textual trace of every hook call to ``self.log``.

    Entries are ``"before:<message>"`` and ``"after:<result>"`` around each
    message, plus ``"started"`` / ``"stopped"`` from the lifecycle hooks.
    """

    def __init__(self):
        self.log: list[str] = []

    async def on_receive(self, ctx, message, next_fn):
        self.log.append(f"before:{message}")
        outcome = await next_fn(ctx, message)
        self.log.append(f"after:{outcome}")
        return outcome

    async def on_started(self, actor_ref):
        self.log.append("started")

    async def on_stopped(self, actor_ref):
        self.log.append("stopped")
|
||
|
||
|
||
class TransformMiddleware(Middleware):
    """Middleware that upper-cases ``str`` messages before forwarding them.

    Non-string messages are forwarded unchanged.
    """

    async def on_receive(self, ctx, message, next_fn):
        forwarded = message.upper() if isinstance(message, str) else message
        return await next_fn(ctx, forwarded)
|
||
|
||
|
||
class TestExecutor:
    """Blocking calls offloaded via ``context.run_in_executor``.

    Systems are shut down in ``finally`` so the executor's worker threads
    are released even when an assertion fails.
    """

    @pytest.mark.anyio
    async def test_run_in_executor(self):
        """Blocking function runs in thread pool without blocking event loop."""
        import time

        class BlockingActor(Actor):
            async def on_receive(self, message):
                # Simulate blocking I/O via thread pool; the return value
                # of time.sleep (None) is not needed.
                await self.context.run_in_executor(time.sleep, 0.01)
                return "done"

        system = ActorSystem("test", executor_workers=2)
        try:
            ref = await system.spawn(BlockingActor, "blocker")
            result = await ref.ask("go", timeout=5.0)
            assert result == "done"
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_concurrent_blocking_calls(self):
        """Multiple actors can run blocking I/O concurrently via shared pool."""
        import time

        class SlowActor(Actor):
            async def on_receive(self, message):
                await self.context.run_in_executor(time.sleep, 0.1)
                return "ok"

        system = ActorSystem("test", executor_workers=4)
        try:
            refs = [await system.spawn(SlowActor, f"s{i}") for i in range(4)]

            start = time.monotonic()
            results = await asyncio.gather(*[r.ask("go", timeout=5.0) for r in refs])
            elapsed = time.monotonic() - start

            assert all(r == "ok" for r in results)
            # 4 parallel × 0.1s should finish in ~0.1-0.2s, not 0.4s
            assert elapsed < 0.3
        finally:
            await system.shutdown()
|
||
|
||
|
||
class TestMiddleware:
    """Middleware interception, chaining order, lifecycle and restart hooks.

    Systems are shut down in ``finally`` so a failed assertion does not
    leave actors running into later tests.
    """

    @pytest.mark.anyio
    async def test_middleware_intercepts_messages(self):
        mw = LogMiddleware()
        system = ActorSystem("test")
        try:
            ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
            result = await ref.ask("hello")
            assert result == "hello"
            assert "before:hello" in mw.log
            assert "after:hello" in mw.log
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_lifecycle_hooks(self):
        mw = LogMiddleware()
        system = ActorSystem("test")
        try:
            ref = await system.spawn(EchoActor, "echo", middlewares=[mw])
            assert "started" in mw.log
            ref.stop()
            await asyncio.sleep(0.1)
            assert "stopped" in mw.log
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_chain_order(self):
        """First middleware wraps outermost — sees original message."""
        mw1 = LogMiddleware()
        mw2 = TransformMiddleware()
        system = ActorSystem("test")
        try:
            # Chain: mw1(mw2(actor)). mw1 logs original, mw2 uppercases, actor echoes
            ref = await system.spawn(EchoActor, "echo", middlewares=[mw1, mw2])
            result = await ref.ask("hello")
            assert result == "HELLO"  # TransformMiddleware uppercased
            assert "before:hello" in mw1.log  # LogMiddleware saw original
            assert "after:HELLO" in mw1.log  # LogMiddleware saw transformed result
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_with_tell(self):
        """Fire-and-forget ``tell()`` messages are routed through middleware too."""
        mw = LogMiddleware()
        system = ActorSystem("test")
        try:
            ref = await system.spawn(CounterActor, "counter", middlewares=[mw])
            # Spawning alone must not produce any message entries.
            assert not any(entry.startswith("before:") for entry in mw.log)

            # tell goes through middleware too
            await ref.tell("inc")
            # Give the actor time to process, as the other tell-based tests do.
            await asyncio.sleep(0.05)
            assert await ref.ask("get") == 1
            assert "before:inc" in mw.log
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_middleware_on_restart_hook(self):
        """on_restart is called on the middleware when a child actor is restarted."""

        class RestartTrackingMiddleware(Middleware):
            def __init__(self):
                self.restart_errors: list[Exception] = []

            async def on_restart(self, actor_ref, error):
                self.restart_errors.append(error)

        mw = RestartTrackingMiddleware()

        class ChildSpawningParent(Actor):
            async def on_receive(self, message):
                if message == "spawn":
                    ref = await self.context.spawn(CrashActor, "child", middlewares=[mw])
                    return ref

        system = ActorSystem("test")
        try:
            parent = await system.spawn(ChildSpawningParent, "parent")
            child = await parent.ask("spawn")

            # Crash the child — parent supervisor will restart it
            try:
                await child.ask("crash")
            except ValueError:
                pass
            await asyncio.sleep(0.1)

            assert len(mw.restart_errors) == 1
            assert isinstance(mw.restart_errors[0], ValueError)
        finally:
            await system.shutdown()
|
||
|
||
|
||
class TestAskErrorPropagation:
    """Exceptions raised in ``on_receive`` surface to the ``ask()`` caller.

    Systems are shut down in ``finally`` so a failed assertion does not
    leave actors running into later tests.
    """

    @pytest.mark.anyio
    async def test_ask_propagates_actor_exception(self):
        """ask() re-raises the original exception type when on_receive crashes."""

        class BoomActor(Actor):
            async def on_receive(self, message):
                raise ValueError("intentional crash")

        system = ActorSystem("test")
        try:
            ref = await system.spawn(BoomActor, "boom")
            with pytest.raises(ValueError, match="intentional crash"):
                await ref.ask("trigger")
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_propagates_exception_while_supervised(self):
        """ask() gets the exception even when the actor is supervised (not stopped)."""

        class SometimesCrashActor(Actor):
            async def on_receive(self, message):
                if message == "crash":
                    raise RuntimeError("supervised crash")
                return "ok"

        system = ActorSystem("test")
        try:
            ref = await system.spawn(SometimesCrashActor, "sca")
            with pytest.raises(RuntimeError, match="supervised crash"):
                await ref.ask("crash")
            # Root actor keeps running after a crash (consecutive_failures, not restart)
            result = await ref.ask("hello", timeout=2.0)
            assert result == "ok"
        finally:
            await system.shutdown()

    @pytest.mark.anyio
    async def test_ask_timeout_late_reply_no_exception(self):
        """Late reply arriving after ask() timeout is silently dropped — no exception, no orphaned future."""

        class SlowActor(Actor):
            async def on_receive(self, message):
                await asyncio.sleep(0.3)
                return "late"

        system = ActorSystem("test")
        try:
            ref = await system.spawn(SlowActor, "slow")

            with pytest.raises(asyncio.TimeoutError):
                await ref.ask("go", timeout=0.05)

            # Wait for actor to finish processing — late reply arrives, should be a no-op
            await asyncio.sleep(0.4)
            # System still functional: no orphaned futures, no leaked state
            assert ref.is_alive
            result = await ref.ask("go", timeout=2.0)
            assert result == "late"
        finally:
            await system.shutdown()
|