"""Tests for the in-memory StreamBridge implementation.""" import asyncio import re import pytest from deerflow.runtime import END_SENTINEL, HEARTBEAT_SENTINEL, MemoryStreamBridge, make_stream_bridge # --------------------------------------------------------------------------- # Unit tests for MemoryStreamBridge # --------------------------------------------------------------------------- @pytest.fixture def bridge() -> MemoryStreamBridge: return MemoryStreamBridge(queue_maxsize=256) @pytest.mark.anyio async def test_publish_subscribe(bridge: MemoryStreamBridge): """Three events followed by end should be received in order.""" run_id = "run-1" await bridge.publish(run_id, "metadata", {"run_id": run_id}) await bridge.publish(run_id, "values", {"messages": []}) await bridge.publish(run_id, "updates", {"step": 1}) await bridge.publish_end(run_id) received = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0): received.append(entry) if entry is END_SENTINEL: break assert len(received) == 4 assert received[0].event == "metadata" assert received[1].event == "values" assert received[2].event == "updates" assert received[3] is END_SENTINEL @pytest.mark.anyio async def test_heartbeat(bridge: MemoryStreamBridge): """When no events arrive within the heartbeat interval, yield a heartbeat.""" run_id = "run-heartbeat" bridge._get_or_create_queue(run_id) # ensure queue exists received = [] async def consumer(): async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1): received.append(entry) if entry is HEARTBEAT_SENTINEL: break await asyncio.wait_for(consumer(), timeout=2.0) assert len(received) == 1 assert received[0] is HEARTBEAT_SENTINEL @pytest.mark.anyio async def test_cleanup(bridge: MemoryStreamBridge): """After cleanup, the run's queue is removed.""" run_id = "run-cleanup" await bridge.publish(run_id, "test", {}) assert run_id in bridge._queues await bridge.cleanup(run_id) assert run_id not in bridge._queues assert run_id not in bridge._counters @pytest.mark.anyio async def test_backpressure(): """With maxsize=1, publish should not block forever.""" bridge = MemoryStreamBridge(queue_maxsize=1) run_id = "run-bp" await bridge.publish(run_id, "first", {}) # Second publish should either succeed after queue drains or warn+drop # It should not hang indefinitely async def publish_second(): await bridge.publish(run_id, "second", {}) # Give it a generous timeout — the publish timeout is 30s but we don't # want to wait that long in tests. Instead, drain the queue first. async def drain(): await asyncio.sleep(0.05) bridge._queues[run_id].get_nowait() await asyncio.gather(publish_second(), drain()) assert bridge._queues[run_id].qsize() == 1 @pytest.mark.anyio async def test_multiple_runs(bridge: MemoryStreamBridge): """Two different run_ids should not interfere with each other.""" await bridge.publish("run-a", "event-a", {"a": 1}) await bridge.publish("run-b", "event-b", {"b": 2}) await bridge.publish_end("run-a") await bridge.publish_end("run-b") events_a = [] async for entry in bridge.subscribe("run-a", heartbeat_interval=1.0): events_a.append(entry) if entry is END_SENTINEL: break events_b = [] async for entry in bridge.subscribe("run-b", heartbeat_interval=1.0): events_b.append(entry) if entry is END_SENTINEL: break assert len(events_a) == 2 assert events_a[0].event == "event-a" assert events_a[0].data == {"a": 1} assert len(events_b) == 2 assert events_b[0].event == "event-b" assert events_b[0].data == {"b": 2} @pytest.mark.anyio async def test_event_id_format(bridge: MemoryStreamBridge): """Event IDs should use timestamp-sequence format.""" run_id = "run-id-format" await bridge.publish(run_id, "test", {"key": "value"}) await bridge.publish_end(run_id) received = [] async for entry in bridge.subscribe(run_id, heartbeat_interval=1.0): received.append(entry) if entry is END_SENTINEL: break event = received[0] assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}" # --------------------------------------------------------------------------- # Factory tests # --------------------------------------------------------------------------- @pytest.mark.anyio async def test_make_stream_bridge_defaults(): """make_stream_bridge() with no config yields a MemoryStreamBridge.""" async with make_stream_bridge() as bridge: assert isinstance(bridge, MemoryStreamBridge)