diff --git a/backend/tests/test_phase2b_integration.py b/backend/tests/test_phase2b_integration.py index 0858efdce..da675e757 100644 --- a/backend/tests/test_phase2b_integration.py +++ b/backend/tests/test_phase2b_integration.py @@ -5,7 +5,6 @@ is correctly written to both RunStore and RunEventStore. """ import asyncio -from unittest.mock import MagicMock from uuid import uuid4 import pytest @@ -15,19 +14,30 @@ from deerflow.runtime.journal import RunJournal from deerflow.runtime.runs.store.memory import MemoryRunStore +class _FakeMessage: + def __init__(self, content, usage): + self.content = content + self.tool_calls = [] + self.response_metadata = {"model_name": "test-model"} + self.usage_metadata = usage + self.id = "test-msg-id" + + def model_dump(self): + return {"type": "ai", "content": self.content, "id": self.id, "tool_calls": [], "usage_metadata": self.usage_metadata, "response_metadata": self.response_metadata} + + +class _FakeGeneration: + def __init__(self, message): + self.message = message + + +class _FakeLLMResult: + def __init__(self, content, usage): + self.generations = [[_FakeGeneration(_FakeMessage(content, usage))]] + + def _make_llm_response(content="Hello", usage=None): - msg = MagicMock() - msg.content = content - msg.tool_calls = [] - msg.response_metadata = {"model_name": "test-model"} - msg.usage_metadata = usage - - gen = MagicMock() - gen.message = msg - - response = MagicMock() - response.generations = [[gen]] - return response + return _FakeLLMResult(content, usage) class TestRunLifecycle: @@ -152,3 +162,118 @@ class TestRunLifecycle: await mgr.set_status(record.run_id, RunStatus.running) row = await run_store.get(record.run_id) assert row["status"] == "running" + + @pytest.mark.anyio + async def test_runmanager_create_or_reject_persists(self): + """create_or_reject also persists to store.""" + from deerflow.runtime.runs.manager import RunManager + + run_store = MemoryRunStore() + mgr = RunManager(store=run_store) + + record = await mgr.create_or_reject("t1", "lead_agent", metadata={"key": "val"}) + row = await run_store.get(record.run_id) + assert row is not None + assert row["status"] == "pending" + assert row["metadata"] == {"key": "val"} + + @pytest.mark.anyio + async def test_follow_up_metadata_in_messages(self): + """human_message metadata carries follow_up_to_run_id.""" + event_store = MemoryRunEventStore() + + # Run 1 + await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1") + await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1") + + # Run 2 (follow-up) + await event_store.put( + thread_id="t1", + run_id="r2", + event_type="human_message", + category="message", + content="Tell me more", + metadata={"follow_up_to_run_id": "r1"}, + ) + + messages = await event_store.list_messages("t1") + assert len(messages) == 3 + assert messages[2]["metadata"]["follow_up_to_run_id"] == "r1" + + @pytest.mark.anyio + async def test_summarization_in_history(self): + """summary message appears correctly in message history.""" + event_store = MemoryRunEventStore() + + await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1") + await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1") + await event_store.put(thread_id="t1", run_id="r2", event_type="summary", category="message", content="Previous conversation summarized.", metadata={"replaced_count": 2}) + await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2") + await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2") + + messages = await event_store.list_messages("t1") + assert len(messages) == 5 + assert messages[2]["event_type"] == "summary" + assert messages[2]["metadata"]["replaced_count"] == 2 + + @pytest.mark.anyio + async def test_db_backed_run_lifecycle(self, tmp_path): + """Full lifecycle with SQLite-backed RunRepository + DbRunEventStore.""" + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + from deerflow.persistence.repositories.run_repo import RunRepository + from deerflow.runtime.events.store.db import DbRunEventStore + from deerflow.runtime.runs.manager import RunManager + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + sf = get_session_factory() + + run_store = RunRepository(sf) + event_store = DbRunEventStore(sf) + mgr = RunManager(store=run_store) + + # Create run + record = await mgr.create("t1", "lead_agent") + run_id = record.run_id + + # Write human_message + await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB") + + # Simulate journal + on_complete_data = {} + journal = RunJournal(run_id, "t1", event_store, on_complete=lambda **d: on_complete_data.update(d), flush_threshold=100) + journal.set_first_human_message("Hello DB") + + journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) + llm_rid = uuid4() + journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"]) + journal.on_llm_end(_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=llm_rid, tags=["lead_agent"]) + journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + await journal.flush() + await asyncio.sleep(0.05) + + # Verify run persisted + row = await run_store.get(run_id) + assert row is not None + assert row["status"] == "pending" # RunManager set it, journal doesn't update status + + # Update completion + await run_store.update_run_completion(run_id, status="success", **on_complete_data) + row = await run_store.get(run_id) + assert row["status"] == "success" + assert row["total_tokens"] == 15 + + # Verify messages from DB + messages = await event_store.list_messages("t1") + assert len(messages) == 2 + assert messages[0]["event_type"] == "human_message" + assert messages[1]["event_type"] == "ai_message" + + # Verify events from DB + events = await event_store.list_events("t1", run_id) + event_types = {e["event_type"] for e in events} + assert "run_start" in event_types + assert "llm_end" in event_types + assert "run_end" in event_types + + await close_engine() diff --git a/backend/tests/test_run_event_store.py b/backend/tests/test_run_event_store.py index f8c6cb177..2b22b2c6f 100644 --- a/backend/tests/test_run_event_store.py +++ b/backend/tests/test_run_event_store.py @@ -357,6 +357,102 @@ class TestDbRunEventStore: await close_engine() + @pytest.mark.anyio + async def test_put_batch_seq_continuity(self, tmp_path): + """Batch write produces continuous seq values with no gaps.""" + from deerflow.persistence.engine import close_engine, get_session_factory, init_engine + from deerflow.runtime.events.store.db import DbRunEventStore + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + s = DbRunEventStore(get_session_factory()) + + events = [{"thread_id": "t1", "run_id": "r1", "event_type": "trace", "category": "trace"} for _ in range(50)] + results = await s.put_batch(events) + seqs = [r["seq"] for r in results] + assert seqs == list(range(1, 51)) + await close_engine() + + +# -- Factory tests -- + + +class TestMakeRunEventStore: + """Tests for the make_run_event_store factory function.""" + + @pytest.mark.anyio + async def test_memory_backend_default(self): + from deerflow.runtime.events.store import make_run_event_store + + store = make_run_event_store(None) + assert type(store).__name__ == "MemoryRunEventStore" + + @pytest.mark.anyio + async def test_memory_backend_explicit(self): + from unittest.mock import MagicMock + + from deerflow.runtime.events.store import make_run_event_store + + config = MagicMock() + config.backend = "memory" + store = make_run_event_store(config) + assert type(store).__name__ == "MemoryRunEventStore" + + @pytest.mark.anyio + async def test_db_backend_with_engine(self, tmp_path): + from unittest.mock import MagicMock + + from deerflow.persistence.engine import close_engine, init_engine + from deerflow.runtime.events.store import make_run_event_store + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + + config = MagicMock() + config.backend = "db" + config.max_trace_content = 10240 + store = make_run_event_store(config) + assert type(store).__name__ == "DbRunEventStore" + await close_engine() + + @pytest.mark.anyio + async def test_db_backend_no_engine_falls_back(self): + """db backend without engine falls back to memory.""" + from unittest.mock import MagicMock + + from deerflow.persistence.engine import close_engine, init_engine + from deerflow.runtime.events.store import make_run_event_store + + await init_engine("memory") # no engine created + + config = MagicMock() + config.backend = "db" + store = make_run_event_store(config) + assert type(store).__name__ == "MemoryRunEventStore" + await close_engine() + + @pytest.mark.anyio + async def test_jsonl_backend(self): + from unittest.mock import MagicMock + + from deerflow.runtime.events.store import make_run_event_store + + config = MagicMock() + config.backend = "jsonl" + store = make_run_event_store(config) + assert type(store).__name__ == "JsonlRunEventStore" + + @pytest.mark.anyio + async def test_unknown_backend_raises(self): + from unittest.mock import MagicMock + + from deerflow.runtime.events.store import make_run_event_store + + config = MagicMock() + config.backend = "redis" + with pytest.raises(ValueError, match="Unknown"): + make_run_event_store(config) + # -- JSONL-specific tests -- diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index 28dcfbddc..e4215586a 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -213,6 +213,92 @@ class TestIdentifyCaller: assert j._identify_caller({}) == "unknown" +class TestChainErrorCallback: + @pytest.mark.anyio + async def test_on_chain_error_writes_run_error(self, journal_setup): + j, store, _ = journal_setup + # parent_run_id must be None (top-level chain) for the event to be recorded + j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None) + # on_chain_error calls _flush_sync internally, give async task time to complete + await asyncio.sleep(0.05) + await j.flush() + events = await store.list_events("t1", "r1") + error_events = [e for e in events if e["event_type"] == "run_error"] + assert len(error_events) == 1 + assert "boom" in error_events[0]["content"] + assert error_events[0]["metadata"]["error_type"] == "ValueError" + + +class TestTokenTrackingDisabled: + @pytest.mark.anyio + async def test_track_token_usage_false(self): + """track_token_usage=False disables token accumulation.""" + store = MemoryRunEventStore() + complete_data = {} + j = RunJournal("r1", "t1", store, track_token_usage=False, on_complete=lambda **d: complete_data.update(d), flush_threshold=100) + j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), run_id=uuid4(), tags=["lead_agent"]) + j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + assert complete_data["total_tokens"] == 0 + assert complete_data["llm_call_count"] == 0 + + +class TestMiddlewareNoMessage: + @pytest.mark.anyio + async def test_on_llm_end_middleware_no_ai_message(self, journal_setup): + j, store, _ = journal_setup + j.on_llm_end(_make_llm_response("Summary"), run_id=uuid4(), tags=["middleware:summarization"]) + await j.flush() + messages = await store.list_messages("t1") + assert len(messages) == 0 + + +class TestUnknownCallerTokens: + @pytest.mark.anyio + async def test_unknown_caller_tokens_go_to_lead(self, journal_setup): + """No caller tag: tokens attributed to lead_agent bucket.""" + j, store, _ = journal_setup + j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=uuid4(), tags=[]) + assert j._lead_agent_tokens == 15 + + +class TestConvenienceFields: + @pytest.mark.anyio + async def test_last_ai_message_tracks_latest(self, journal_setup): + j, store, complete_data = journal_setup + j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"]) + j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"]) + j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + assert complete_data["last_ai_message"] == "Second" + assert complete_data["message_count"] == 2 + + @pytest.mark.anyio + async def test_first_human_message_via_set(self, journal_setup): + j, store, complete_data = journal_setup + j.set_first_human_message("What is AI?") + j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) + assert complete_data["first_human_message"] == "What is AI?" + + +class TestToolError: + @pytest.mark.anyio + async def test_on_tool_error(self, journal_setup): + j, store, _ = journal_setup + j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch") + await j.flush() + events = await store.list_events("t1", "r1") + assert any(e["event_type"] == "tool_error" for e in events) + + +class TestOtherCustomEvent: + @pytest.mark.anyio + async def test_non_summarization_custom_event(self, journal_setup): + j, store, _ = journal_setup + j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4()) + await j.flush() + events = await store.list_events("t1", "r1") + assert any(e["event_type"] == "task_running" for e in events) + + class TestPublicMethods: @pytest.mark.anyio async def test_set_first_human_message(self, journal_setup): diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 5198952c4..c1ecabc99 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -153,3 +153,44 @@ class TestRunRepository: row = await repo.get("r1") assert "obj" in row["kwargs"] await _cleanup() + + @pytest.mark.anyio + async def test_update_run_completion_preserves_existing_fields(self, tmp_path): + """update_run_completion does not overwrite thread_id or assistant_id.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", assistant_id="agent1", status="running") + await repo.update_run_completion("r1", status="success", total_tokens=100) + row = await repo.get("r1") + assert row["thread_id"] == "t1" + assert row["assistant_id"] == "agent1" + assert row["total_tokens"] == 100 + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_thread_ordered_desc(self, tmp_path): + """list_by_thread returns newest first.""" + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", created_at="2024-01-01T00:00:00+00:00") + await repo.put("r2", thread_id="t1", created_at="2024-01-02T00:00:00+00:00") + rows = await repo.list_by_thread("t1") + assert rows[0]["run_id"] == "r2" + assert rows[1]["run_id"] == "r1" + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_thread_limit(self, tmp_path): + repo = await _make_repo(tmp_path) + for i in range(5): + await repo.put(f"r{i}", thread_id="t1") + rows = await repo.list_by_thread("t1", limit=2) + assert len(rows) == 2 + await _cleanup() + + @pytest.mark.anyio + async def test_owner_none_returns_all(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", owner_id="alice") + await repo.put("r2", thread_id="t1", owner_id="bob") + rows = await repo.list_by_thread("t1", owner_id=None) + assert len(rows) == 2 + await _cleanup() diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py new file mode 100644 index 000000000..9104275ff --- /dev/null +++ b/backend/tests/test_thread_meta_repo.py @@ -0,0 +1,132 @@ +"""Tests for ThreadMetaRepository (SQLAlchemy-backed).""" + +import pytest + +from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository + + +async def _make_repo(tmp_path): + from deerflow.persistence.engine import get_session_factory, init_engine + + url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}" + await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)) + return ThreadMetaRepository(get_session_factory()) + + +async def _cleanup(): + from deerflow.persistence.engine import close_engine + + await close_engine() + + +class TestThreadMetaRepository: + @pytest.mark.anyio + async def test_create_and_get(self, tmp_path): + repo = await _make_repo(tmp_path) + record = await repo.create("t1") + assert record["thread_id"] == "t1" + assert record["status"] == "idle" + assert "created_at" in record + + fetched = await repo.get("t1") + assert fetched is not None + assert fetched["thread_id"] == "t1" + await _cleanup() + + @pytest.mark.anyio + async def test_create_with_assistant_id(self, tmp_path): + repo = await _make_repo(tmp_path) + record = await repo.create("t1", assistant_id="agent1") + assert record["assistant_id"] == "agent1" + await _cleanup() + + @pytest.mark.anyio + async def test_create_with_owner_and_display_name(self, tmp_path): + repo = await _make_repo(tmp_path) + record = await repo.create("t1", owner_id="user1", display_name="My Thread") + assert record["owner_id"] == "user1" + assert record["display_name"] == "My Thread" + await _cleanup() + + @pytest.mark.anyio + async def test_create_with_metadata(self, tmp_path): + repo = await _make_repo(tmp_path) + record = await repo.create("t1", metadata={"key": "value"}) + assert record["metadata"] == {"key": "value"} + await _cleanup() + + @pytest.mark.anyio + async def test_get_nonexistent(self, tmp_path): + repo = await _make_repo(tmp_path) + assert await repo.get("nonexistent") is None + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_owner(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.create("t1", owner_id="user1") + await repo.create("t2", owner_id="user1") + await repo.create("t3", owner_id="user2") + results = await repo.list_by_owner("user1") + assert len(results) == 2 + assert all(r["owner_id"] == "user1" for r in results) + await _cleanup() + + @pytest.mark.anyio + async def test_list_by_owner_with_limit_and_offset(self, tmp_path): + repo = await _make_repo(tmp_path) + for i in range(5): + await repo.create(f"t{i}", owner_id="user1") + results = await repo.list_by_owner("user1", limit=2, offset=1) + assert len(results) == 2 + await _cleanup() + + @pytest.mark.anyio + async def test_check_access_no_record_allows(self, tmp_path): + repo = await _make_repo(tmp_path) + assert await repo.check_access("unknown", "user1") is True + await _cleanup() + + @pytest.mark.anyio + async def test_check_access_owner_matches(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.create("t1", owner_id="user1") + assert await repo.check_access("t1", "user1") is True + await _cleanup() + + @pytest.mark.anyio + async def test_check_access_owner_mismatch(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.create("t1", owner_id="user1") + assert await repo.check_access("t1", "user2") is False + await _cleanup() + + @pytest.mark.anyio + async def test_check_access_no_owner_allows_all(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.create("t1") # owner_id=None + assert await repo.check_access("t1", "anyone") is True + await _cleanup() + + @pytest.mark.anyio + async def test_update_status(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.create("t1") + await repo.update_status("t1", "busy") + record = await repo.get("t1") + assert record["status"] == "busy" + await _cleanup() + + @pytest.mark.anyio + async def test_delete(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.create("t1") + await repo.delete("t1") + assert await repo.get("t1") is None + await _cleanup() + + @pytest.mark.anyio + async def test_delete_nonexistent_is_noop(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.delete("nonexistent") # should not raise + await _cleanup() diff --git a/config.example.yaml b/config.example.yaml index ca1ecbe87..8654e6b8f 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -232,7 +232,6 @@ models: # supports_vision: true # supports_thinking: true - # Example: OpenRouter (OpenAI-compatible) # OpenRouter models use the same ChatOpenAI + base_url pattern as other OpenAI-compatible gateways. # - name: openrouter-gemini-2.5-flash @@ -552,34 +551,20 @@ memory: max_injection_tokens: 2000 # Maximum tokens for memory injection # ============================================================================ -# Checkpointer Configuration +# Checkpointer Configuration (DEPRECATED — use `database` instead) # ============================================================================ -# Configure state persistence for the embedded DeerFlowClient. -# The LangGraph Server manages its own state persistence separately -# via the server infrastructure (this setting does not affect it). +# Legacy standalone checkpointer config. Kept for backward compatibility. +# Prefer the unified `database` section below, which drives BOTH the +# LangGraph checkpointer AND DeerFlow application data (runs, feedback, +# events) from a single backend setting. # -# When configured, DeerFlowClient will automatically use this checkpointer, -# enabling multi-turn conversations to persist across process restarts. +# If both `checkpointer` and `database` are present, `checkpointer` +# takes precedence for LangGraph state persistence only. # -# Supported types: -# memory - In-process only. State is lost when the process exits. (default) -# sqlite - File-based SQLite persistence. Survives restarts. -# Requires: uv add langgraph-checkpoint-sqlite -# postgres - PostgreSQL persistence. Suitable for multi-process deployments. -# Requires: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool -# -# Examples: -# -# In-memory (default when omitted — no persistence): # checkpointer: -# type: memory +# type: sqlite +# connection_string: checkpoints.db # -# SQLite (file-based, single-process): -checkpointer: - type: sqlite - connection_string: checkpoints.db -# -# PostgreSQL (multi-process, production): # checkpointer: # type: postgres # connection_string: postgresql://user:password@localhost:5432/deerflow @@ -588,7 +573,7 @@ checkpointer: # Database # ============================================================================ # Unified storage backend for LangGraph checkpointer and DeerFlow -# application data (runs, threads metadata, etc.). +# application data (runs, threads metadata, feedback, etc.). # # backend: memory -- No persistence, data lost on restart (default) # backend: sqlite -- Single-node deployment, files in sqlite_dir @@ -604,7 +589,6 @@ checkpointer: # NOTE: When both `checkpointer` and `database` are configured, # `checkpointer` takes precedence for LangGraph state persistence. # If you use `database`, you can remove the `checkpointer` section. -# # database: # backend: sqlite # sqlite_dir: ./data @@ -612,6 +596,9 @@ checkpointer: # database: # backend: postgres # postgres_url: $DATABASE_URL +database: + backend: sqlite + sqlite_dir: ./data # ============================================================================ # Run Events Configuration @@ -626,6 +613,10 @@ checkpointer: # backend: memory # max_trace_content: 10240 # Truncation threshold for trace content (db backend, bytes) # track_token_usage: true # Accumulate token counts to RunRow +run_events: + backend: memory + max_trace_content: 10240 + track_token_usage: true # ============================================================================ # IM Channels Configuration