"""Tests for RunManager.""" import re import pytest from deerflow.runtime import RunManager, RunStatus from deerflow.runtime.runs.store.memory import MemoryRunStore ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}") @pytest.fixture def manager() -> RunManager: return RunManager() @pytest.mark.anyio async def test_create_and_get(manager: RunManager): """Created run should be retrievable with new fields.""" record = await manager.create( "thread-1", "lead_agent", metadata={"key": "val"}, kwargs={"input": {}}, multitask_strategy="reject", ) assert record.status == RunStatus.pending assert record.thread_id == "thread-1" assert record.assistant_id == "lead_agent" assert record.metadata == {"key": "val"} assert record.kwargs == {"input": {}} assert record.multitask_strategy == "reject" assert ISO_RE.match(record.created_at) assert ISO_RE.match(record.updated_at) fetched = manager.get(record.run_id) assert fetched is record @pytest.mark.anyio async def test_status_transitions(manager: RunManager): """Status should transition pending -> running -> success.""" record = await manager.create("thread-1") assert record.status == RunStatus.pending await manager.set_status(record.run_id, RunStatus.running) assert record.status == RunStatus.running assert ISO_RE.match(record.updated_at) await manager.set_status(record.run_id, RunStatus.success) assert record.status == RunStatus.success @pytest.mark.anyio async def test_cancel(manager: RunManager): """Cancel should set abort_event and transition to interrupted.""" record = await manager.create("thread-1") await manager.set_status(record.run_id, RunStatus.running) cancelled = await manager.cancel(record.run_id) assert cancelled is True assert record.abort_event.is_set() assert record.status == RunStatus.interrupted @pytest.mark.anyio async def test_cancel_not_inflight(manager: RunManager): """Cancelling a completed run should return False.""" record = await manager.create("thread-1") await manager.set_status(record.run_id, RunStatus.success) cancelled = await manager.cancel(record.run_id) assert cancelled is False @pytest.mark.anyio async def test_list_by_thread(manager: RunManager): """Same thread should return multiple runs.""" r1 = await manager.create("thread-1") r2 = await manager.create("thread-1") await manager.create("thread-2") runs = await manager.list_by_thread("thread-1") assert len(runs) == 2 # list_by_thread returns oldest-first (ascending created_at). assert runs[0].run_id == r1.run_id assert runs[1].run_id == r2.run_id @pytest.mark.anyio async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch): """Ordering should be stable (insertion order) even when timestamps tie.""" monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00") r1 = await manager.create("thread-1") r2 = await manager.create("thread-1") runs = await manager.list_by_thread("thread-1") assert [run.run_id for run in runs] == [r1.run_id, r2.run_id] @pytest.mark.anyio async def test_has_inflight(manager: RunManager): """has_inflight should be True when a run is pending or running.""" record = await manager.create("thread-1") assert await manager.has_inflight("thread-1") is True await manager.set_status(record.run_id, RunStatus.success) assert await manager.has_inflight("thread-1") is False @pytest.mark.anyio async def test_cleanup(manager: RunManager): """After cleanup, the run should be gone.""" record = await manager.create("thread-1") run_id = record.run_id await manager.cleanup(run_id, delay=0) assert manager.get(run_id) is None @pytest.mark.anyio async def test_set_status_with_error(manager: RunManager): """Error message should be stored on the record.""" record = await manager.create("thread-1") await manager.set_status(record.run_id, RunStatus.error, error="Something went wrong") assert record.status == RunStatus.error assert record.error == "Something went wrong" @pytest.mark.anyio async def test_get_nonexistent(manager: RunManager): """Getting a nonexistent run should return None.""" assert manager.get("does-not-exist") is None @pytest.mark.anyio async def test_create_defaults(manager: RunManager): """Create with no optional args should use defaults.""" record = await manager.create("thread-1") assert record.metadata == {} assert record.kwargs == {} assert record.multitask_strategy == "reject" assert record.assistant_id is None @pytest.mark.anyio async def test_model_name_create_or_reject(): """create_or_reject should accept and persist model_name.""" from deerflow.runtime.runs.schemas import DisconnectMode store = MemoryRunStore() mgr = RunManager(store=store) record = await mgr.create_or_reject( "thread-1", assistant_id="lead_agent", on_disconnect=DisconnectMode.cancel, metadata={"key": "val"}, kwargs={"input": {}}, multitask_strategy="reject", model_name="anthropic.claude-sonnet-4-20250514-v1:0", ) assert record.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" assert record.status == RunStatus.pending # Verify model_name was persisted to store stored = await store.get(record.run_id) assert stored is not None assert stored["model_name"] == "anthropic.claude-sonnet-4-20250514-v1:0" # Verify retrieval returns the model_name via in-memory record fetched = mgr.get(record.run_id) assert fetched is not None assert fetched.model_name == "anthropic.claude-sonnet-4-20250514-v1:0" @pytest.mark.anyio async def test_model_name_default_is_none(): """create_or_reject without model_name should default to None.""" from deerflow.runtime.runs.schemas import DisconnectMode store = MemoryRunStore() mgr = RunManager(store=store) record = await mgr.create_or_reject( "thread-1", on_disconnect=DisconnectMode.cancel, model_name=None, ) assert record.model_name is None stored = await store.get(record.run_id) assert stored["model_name"] is None # --------------------------------------------------------------------------- # Store fallback tests (simulates gateway restart scenario) # --------------------------------------------------------------------------- @pytest.fixture def manager_with_store() -> RunManager: """RunManager backed by a MemoryRunStore.""" return RunManager(store=MemoryRunStore()) @pytest.mark.anyio async def test_list_by_thread_returns_store_records_after_restart(manager_with_store: RunManager): """After in-memory state is cleared (simulating restart), list_by_thread should still return runs from the persistent store.""" mgr = manager_with_store r1 = await mgr.create("thread-1", "agent-1") await mgr.set_status(r1.run_id, RunStatus.success) r2 = await mgr.create("thread-1", "agent-2") await mgr.set_status(r2.run_id, RunStatus.error, error="boom") # Clear in-memory dict to simulate a restart mgr._runs.clear() runs = await mgr.list_by_thread("thread-1") assert len(runs) == 2 statuses = {r.run_id: r.status for r in runs} assert statuses[r1.run_id] == RunStatus.success assert statuses[r2.run_id] == RunStatus.error # Verify other fields survive the round-trip for r in runs: assert r.thread_id == "thread-1" assert ISO_RE.match(r.created_at) @pytest.mark.anyio async def test_list_by_thread_merges_in_memory_and_store(manager_with_store: RunManager): """In-memory runs should be included alongside store-only records.""" mgr = manager_with_store # Create a run and let it complete (will be in both memory and store) r1 = await mgr.create("thread-1") await mgr.set_status(r1.run_id, RunStatus.success) # Simulate restart: clear memory, then create a new in-memory run mgr._runs.clear() r2 = await mgr.create("thread-1") runs = await mgr.list_by_thread("thread-1") assert len(runs) == 2 run_ids = {r.run_id for r in runs} assert r1.run_id in run_ids assert r2.run_id in run_ids # r2 should be the in-memory record (has live state) r2_record = next(r for r in runs if r.run_id == r2.run_id) assert r2_record is r2 # same object reference @pytest.mark.anyio async def test_list_by_thread_no_store(): """Without a store, list_by_thread should only return in-memory runs.""" mgr = RunManager() await mgr.create("thread-1") mgr._runs.clear() runs = await mgr.list_by_thread("thread-1") assert runs == [] @pytest.mark.anyio async def test_aget_returns_in_memory_record(manager_with_store: RunManager): """aget should return the in-memory record when available.""" mgr = manager_with_store r1 = await mgr.create("thread-1", "agent-1") result = await mgr.aget(r1.run_id) assert result is r1 # same object @pytest.mark.anyio async def test_aget_falls_back_to_store(manager_with_store: RunManager): """aget should return a record from the store when not in memory.""" mgr = manager_with_store r1 = await mgr.create("thread-1", "agent-1") await mgr.set_status(r1.run_id, RunStatus.success) mgr._runs.clear() result = await mgr.aget(r1.run_id) assert result is not None assert result.run_id == r1.run_id assert result.status == RunStatus.success assert result.thread_id == "thread-1" assert result.assistant_id == "agent-1" @pytest.mark.anyio async def test_aget_falls_back_to_store_with_user_filter(): """aget should honor user_id when reading store-only records.""" store = MemoryRunStore() await store.put("run-1", thread_id="thread-1", user_id="user-1", status="success") mgr = RunManager(store=store) allowed = await mgr.aget("run-1", user_id="user-1") denied = await mgr.aget("run-1", user_id="user-2") assert allowed is not None assert denied is None @pytest.mark.anyio async def test_aget_returns_none_for_unknown(manager_with_store: RunManager): """aget should return None for a run ID that doesn't exist anywhere.""" result = await manager_with_store.aget("nonexistent-run-id") assert result is None @pytest.mark.anyio async def test_aget_store_failure_is_graceful(): """If the store raises, aget should return None instead of propagating.""" from unittest.mock import AsyncMock store = MemoryRunStore() store.get = AsyncMock(side_effect=RuntimeError("db down")) mgr = RunManager(store=store) result = await mgr.aget("some-id") assert result is None @pytest.mark.anyio async def test_list_by_thread_store_failure_is_graceful(): """If the store raises, list_by_thread should return only in-memory runs.""" from unittest.mock import AsyncMock store = MemoryRunStore() store.list_by_thread = AsyncMock(side_effect=RuntimeError("db down")) mgr = RunManager(store=store) r1 = await mgr.create("thread-1") runs = await mgr.list_by_thread("thread-1") assert len(runs) == 1 assert runs[0].run_id == r1.run_id @pytest.mark.anyio async def test_list_by_thread_falls_back_to_store_with_user_filter(): """list_by_thread should return only the requesting user's store records.""" store = MemoryRunStore() await store.put("run-1", thread_id="thread-1", user_id="user-1", status="success") await store.put("run-2", thread_id="thread-1", user_id="user-2", status="success") mgr = RunManager(store=store) runs = await mgr.list_by_thread("thread-1", user_id="user-1") assert [r.run_id for r in runs] == ["run-1"]