diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py index e4a21d006..9374769f3 100644 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ b/backend/packages/harness/deerflow/runtime/events/store/db.py @@ -9,6 +9,7 @@ from __future__ import annotations import json import logging from datetime import UTC, datetime +from typing import Any from sqlalchemy import delete, func, select from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker @@ -33,20 +34,21 @@ class DbRunEventStore(RunEventStore): if isinstance(val, datetime): d["created_at"] = val.isoformat() d.pop("id", None) - # Restore dict content that was JSON-serialized on write + # Restore structured content that was JSON-serialized on write. raw = d.get("content", "") - if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"): + metadata = d.get("metadata", {}) + if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")): try: d["content"] = json.loads(raw) except (json.JSONDecodeError, ValueError): - # Content looked like JSON (content_is_dict flag) but failed to parse; + # Content looked like JSON but failed to parse; # keep the raw string as-is. 
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq")) return d - def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]: + def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]: if category == "trace": - text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content + text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False) encoded = text.encode("utf-8") if len(encoded) > self._max_trace_content: # Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore") @@ -54,6 +56,18 @@ class DbRunEventStore(RunEventStore): metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)} return content, metadata or {} + @staticmethod + def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]: + metadata = metadata or {} + if isinstance(content, str): + return content, metadata + + db_content = json.dumps(content, default=str, ensure_ascii=False) + metadata = {**metadata, "content_is_json": True} + if isinstance(content, dict): + metadata["content_is_dict"] = True + return db_content, metadata + @staticmethod def _user_id_from_context() -> str | None: """Soft read of user_id from contextvar for write paths. @@ -82,11 +96,7 @@ class DbRunEventStore(RunEventStore): the initial ``human_message`` event (once per run). 
@pytest.mark.anyio
async def test_structured_content_round_trips(self, tmp_path):
    """List content comes back unchanged from ``put`` and ``list_messages``.

    Also checks that list content is flagged ``content_is_json`` but not
    ``content_is_dict``.
    """
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
    from deerflow.runtime.events.store.db import DbRunEventStore

    url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    try:
        s = DbRunEventStore(get_session_factory())

        content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.test/a.png"}}]
        record = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=content)

        assert record["content"] == content
        assert record["metadata"]["content_is_json"] is True
        assert "content_is_dict" not in record["metadata"]

        messages = await s.list_messages("t1")
        assert messages[0]["content"] == content
        assert messages[0]["metadata"]["content_is_json"] is True
    finally:
        # Close the engine even when an assertion fails so it does not leak
        # into later tests.
        await close_engine()


@pytest.mark.anyio
async def test_put_batch_accepts_structured_content(self, tmp_path):
    """List content written through ``put_batch`` comes back unchanged from ``list_events``."""
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
    from deerflow.runtime.events.store.db import DbRunEventStore

    url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    try:
        s = DbRunEventStore(get_session_factory())

        content = [{"messages": [{"type": "ai", "content": ""}]}]
        results = await s.put_batch(
            [
                {
                    "thread_id": "t1",
                    "run_id": "r1",
                    "event_type": "run.end",
                    "category": "outputs",
                    "content": content,
                }
            ]
        )

        assert results[0]["content"] == content
        assert results[0]["metadata"]["content_is_json"] is True

        events = await s.list_events("t1", "r1")
        assert events[0]["content"] == content
        assert events[0]["metadata"]["content_is_json"] is True
    finally:
        # Close the engine even when an assertion fails so it does not leak
        # into later tests.
        await close_engine()


@pytest.mark.anyio
async def test_dict_content_keeps_legacy_metadata_flag(self, tmp_path):
    """Dict content carries both ``content_is_json`` and the legacy ``content_is_dict`` flag."""
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
    from deerflow.runtime.events.store.db import DbRunEventStore

    url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
    try:
        s = DbRunEventStore(get_session_factory())

        content = {"status": "success"}
        record = await s.put(thread_id="t1", run_id="r1", event_type="run.end", category="outputs", content=content)

        assert record["content"] == content
        assert record["metadata"]["content_is_json"] is True
        assert record["metadata"]["content_is_dict"] is True
    finally:
        # Close the engine even when an assertion fails so it does not leak
        # into later tests.
        await close_engine()