Mirror of https://github.com/bytedance/deer-flow.git (synced 2026-04-26 11:48:10 +00:00)

feat(events): widen content type to str | dict in all store backends

Allow event content to be a dict (for structured OpenAI-format messages) in addition to plain strings. Dict values are JSON-serialized for the DB backend and deserialized on read; the memory and JSONL backends handle dicts natively. Trace truncation now serializes dicts to JSON before measuring their length.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
commit 17eb509dbd
parent b92ddafd4b
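
For context, here is a minimal sketch of the behavior this commit enables, based on the store API exercised by the tests in the diff below. The import path for MemoryRunEventStore is an assumption; only DbRunEventStore's module path appears in this diff.

import asyncio

# Assumed import path -- only deerflow.runtime.events.store.db is confirmed by this diff.
from deerflow.runtime.events.store.memory import MemoryRunEventStore

async def demo() -> None:
    store = MemoryRunEventStore()
    # Structured OpenAI-format content can now be passed as a dict instead of a string.
    await store.put(
        thread_id="t1",
        run_id="r1",
        event_type="ai_message",
        category="message",
        content={"role": "assistant", "content": "Hello"},
    )
    # The memory backend hands the dict back unchanged; plain strings keep working as before.
    messages = await store.list_messages("t1")
    assert messages[0]["content"] == {"role": "assistant", "content": "Hello"}

asyncio.run(demo())
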
@@ -33,7 +33,7 @@ class RunEventStore(abc.ABC):
         run_id: str,
         event_type: str,
         category: str,
-        content: str = "",
+        content: str | dict = "",
         metadata: dict | None = None,
         created_at: str | None = None,
     ) -> dict:

@@ -6,6 +6,7 @@ at ``max_trace_content`` bytes to avoid bloating the database.
 
 from __future__ import annotations
 
+import json
 from datetime import UTC, datetime
 
 from sqlalchemy import delete, func, select

@@ -28,16 +29,26 @@ class DbRunEventStore(RunEventStore):
         if isinstance(val, datetime):
             d["created_at"] = val.isoformat()
         d.pop("id", None)
+        # Restore dict content that was JSON-serialized on write
+        content = d.get("content", "")
+        if isinstance(content, str) and content and content[0] in ("{", "["):
+            try:
+                d["content"] = json.loads(content)
+            except (json.JSONDecodeError, ValueError):
+                pass
         return d
 
-    def _truncate_trace(self, category: str, content: str, metadata: dict | None) -> tuple[str, dict]:
-        if category == "trace" and len(content) > self._max_trace_content:
-            content = content[: self._max_trace_content]
-            metadata = {**(metadata or {}), "content_truncated": True}
+    def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
+        if category == "trace":
+            text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
+            if len(text) > self._max_trace_content:
+                content = text[: self._max_trace_content]
+                metadata = {**(metadata or {}), "content_truncated": True}
         return content, metadata or {}
 
     async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
         content, metadata = self._truncate_trace(category, content, metadata)
+        db_content = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
         async with self._sf() as session:
             max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
             seq = (max_seq or 0) + 1

@@ -46,7 +57,7 @@ class DbRunEventStore(RunEventStore):
                 run_id=run_id,
                 event_type=event_type,
                 category=category,
-                content=content,
+                content=db_content,
                 event_metadata=metadata,
                 seq=seq,
                 created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),

@@ -71,12 +82,13 @@ class DbRunEventStore(RunEventStore):
                 category = e.get("category", "trace")
                 metadata = e.get("metadata")
                 content, metadata = self._truncate_trace(category, content, metadata)
+                db_content = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
                 row = RunEventRow(
                     thread_id=e["thread_id"],
                     run_id=e["run_id"],
                     event_type=e["event_type"],
                     category=category,
-                    content=content,
+                    content=db_content,
                     event_metadata=metadata,
                     seq=seq,
                     created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),

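
One detail worth noting about the DB read path above: restoring dict content is a best-effort heuristic, so a trace whose JSON was cut off by truncation simply fails to parse and is returned as the stored string (the truncation test further down relies on this). Below is a standalone sketch of that behavior; restore_content is a hypothetical name standing in for the inline logic in the row-to-dict conversion.

import json

def restore_content(stored):
    # Hypothetical helper mirroring the inline read-path logic above: only strings that
    # look like a JSON object/array are parsed back to Python values; everything else
    # (including JSON cut off mid-way by truncation) stays a plain string.
    if isinstance(stored, str) and stored and stored[0] in ("{", "["):
        try:
            return json.loads(stored)
        except (json.JSONDecodeError, ValueError):
            return stored
    return stored

assert restore_content('{"role": "assistant"}') == {"role": "assistant"}
assert restore_content('{"role": "assi') == '{"role": "assi'  # truncated trace stays a str
assert restore_content("plain text") == "plain text"
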
@@ -29,7 +29,7 @@ class MemoryRunEventStore(RunEventStore):
         run_id: str,
         event_type: str,
         category: str,
-        content: str = "",
+        content: str | dict = "",
         metadata: dict | None = None,
         created_at: str | None = None,
     ) -> dict:

@@ -243,7 +243,7 @@ class RunJournal(BaseCallbackHandler):
 
     # -- Internal methods --
 
-    def _put(self, *, event_type: str, category: str, content: str = "", metadata: dict | None = None) -> None:
+    def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
         self._buffer.append({
             "thread_id": self.thread_id,
             "run_id": self.run_id,

@@ -417,3 +417,88 @@ class TestDbBackedLifecycle:
         assert "run_end" in event_types
 
         await close_engine()
+
+
+class TestDictContent:
+    """Verify that store backends accept str | dict content."""
+
+    @pytest.mark.anyio
+    async def test_memory_store_dict_content(self):
+        store = MemoryRunEventStore()
+        record = await store.put(
+            thread_id="t1",
+            run_id="r1",
+            event_type="ai_message",
+            category="message",
+            content={"role": "assistant", "content": "Hello"},
+        )
+        assert record["content"] == {"role": "assistant", "content": "Hello"}
+        messages = await store.list_messages("t1")
+        assert len(messages) == 1
+        assert messages[0]["content"] == {"role": "assistant", "content": "Hello"}
+
+    @pytest.mark.anyio
+    async def test_memory_store_str_content_unchanged(self):
+        store = MemoryRunEventStore()
+        record = await store.put(
+            thread_id="t1",
+            run_id="r1",
+            event_type="ai_message",
+            category="message",
+            content="plain string",
+        )
+        assert record["content"] == "plain string"
+        assert isinstance(record["content"], str)
+
+    @pytest.mark.anyio
+    async def test_db_store_dict_content_roundtrip(self, tmp_path):
+        """Dict content survives DB roundtrip (JSON serialize on write, deserialize on read)."""
+        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
+        from deerflow.runtime.events.store.db import DbRunEventStore
+
+        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
+        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
+        sf = get_session_factory()
+        store = DbRunEventStore(sf)
+
+        nested = {"role": "assistant", "content": "Hi", "metadata": {"model": "gpt-4", "tokens": [1, 2, 3]}}
+        record = await store.put(
+            thread_id="t1",
+            run_id="r1",
+            event_type="ai_message",
+            category="message",
+            content=nested,
+        )
+        assert record["content"] == nested
+
+        messages = await store.list_messages("t1")
+        assert len(messages) == 1
+        assert messages[0]["content"] == nested
+
+        await close_engine()
+
+    @pytest.mark.anyio
+    async def test_db_store_trace_dict_truncation(self, tmp_path):
+        """Large dict trace content is truncated with metadata flag."""
+        from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
+        from deerflow.runtime.events.store.db import DbRunEventStore
+
+        url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
+        await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
+        sf = get_session_factory()
+        store = DbRunEventStore(sf, max_trace_content=100)
+
+        large_dict = {"role": "assistant", "content": "x" * 200}
+        record = await store.put(
+            thread_id="t1",
+            run_id="r1",
+            event_type="llm_end",
+            category="trace",
+            content=large_dict,
+        )
+        assert record["metadata"].get("content_truncated") is True
+        # Content should be a truncated string (serialized JSON was too long)
+        assert isinstance(record["content"], str)
+        assert len(record["content"]) <= 100
+
+        await close_engine()