mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
refactor(harness): remove old persistence layer from deerflow package
Remove the following deprecated modules: - deerflow/persistence/ - old SQL persistence models and repositories - deerflow/config/checkpointer_config.py - checkpointer configuration - deerflow/config/database_config.py - database configuration - deerflow/runtime/checkpointer/ - checkpointer providers - deerflow/runtime/store/ - store providers - deerflow/runtime/events/ - event store implementations - deerflow/runtime/journal.py - run journal These components are replaced by the new storage layer in app/infra/. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
2fe0856e33
commit
37fd8b0d7a
@ -1,46 +0,0 @@
|
||||
"""Configuration for LangGraph checkpointer."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
CheckpointerType = Literal["memory", "sqlite", "postgres"]
|
||||
|
||||
|
||||
class CheckpointerConfig(BaseModel):
|
||||
"""Configuration for LangGraph state persistence checkpointer."""
|
||||
|
||||
type: CheckpointerType = Field(
|
||||
description="Checkpointer backend type. "
|
||||
"'memory' is in-process only (lost on restart). "
|
||||
"'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). "
|
||||
"'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)."
|
||||
)
|
||||
connection_string: str | None = Field(
|
||||
default=None,
|
||||
description="Connection string for sqlite (file path) or postgres (DSN). "
|
||||
"Required for sqlite and postgres types. "
|
||||
"For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. "
|
||||
"For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.",
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance — None means no checkpointer is configured.
|
||||
_checkpointer_config: CheckpointerConfig | None = None
|
||||
|
||||
|
||||
def get_checkpointer_config() -> CheckpointerConfig | None:
|
||||
"""Get the current checkpointer configuration, or None if not configured."""
|
||||
return _checkpointer_config
|
||||
|
||||
|
||||
def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
|
||||
"""Set the checkpointer configuration."""
|
||||
global _checkpointer_config
|
||||
_checkpointer_config = config
|
||||
|
||||
|
||||
def load_checkpointer_config_from_dict(config_dict: dict) -> None:
|
||||
"""Load checkpointer configuration from a dictionary."""
|
||||
global _checkpointer_config
|
||||
_checkpointer_config = CheckpointerConfig(**config_dict)
|
||||
@ -1,102 +0,0 @@
|
||||
"""Unified database backend configuration.
|
||||
|
||||
Controls BOTH the LangGraph checkpointer and the DeerFlow application
|
||||
persistence layer (runs, threads metadata, users, etc.). The user
|
||||
configures one backend; the system handles physical separation details.
|
||||
|
||||
SQLite mode: checkpointer and app share a single .db file
|
||||
({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every
|
||||
connection. WAL allows concurrent readers and a single writer without
|
||||
blocking, making a unified file safe for both workloads. Writers
|
||||
that contend for the lock wait via the default 5-second sqlite3
|
||||
busy timeout rather than failing immediately.
|
||||
|
||||
Postgres mode: both use the same database URL but maintain independent
|
||||
connection pools with different lifecycles.
|
||||
|
||||
Memory mode: checkpointer uses MemorySaver, app uses in-memory stores.
|
||||
No database is initialized.
|
||||
|
||||
Sensitive values (postgres_url) should use $VAR syntax in config.yaml
|
||||
to reference environment variables from .env:
|
||||
|
||||
database:
|
||||
backend: postgres
|
||||
postgres_url: $DATABASE_URL
|
||||
|
||||
The $VAR resolution is handled by AppConfig.resolve_env_variables()
|
||||
before this config is instantiated -- DatabaseConfig itself does not
|
||||
need to do any environment variable processing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DatabaseConfig(BaseModel):
|
||||
backend: Literal["memory", "sqlite", "postgres"] = Field(
|
||||
default="memory",
|
||||
description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."),
|
||||
)
|
||||
sqlite_dir: str = Field(
|
||||
default=".deer-flow/data",
|
||||
description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."),
|
||||
)
|
||||
postgres_url: str = Field(
|
||||
default="",
|
||||
description=(
|
||||
"PostgreSQL connection URL, shared by checkpointer and app. "
|
||||
"Use $DATABASE_URL in config.yaml to reference .env. "
|
||||
"Example: postgresql://user:pass@host:5432/deerflow "
|
||||
"(the +asyncpg driver suffix is added automatically where needed)."
|
||||
),
|
||||
)
|
||||
echo_sql: bool = Field(
|
||||
default=False,
|
||||
description="Echo all SQL statements to log (debug only).",
|
||||
)
|
||||
pool_size: int = Field(
|
||||
default=5,
|
||||
description="Connection pool size for the app ORM engine (postgres only).",
|
||||
)
|
||||
|
||||
# -- Derived helpers (not user-configured) --
|
||||
|
||||
@property
|
||||
def _resolved_sqlite_dir(self) -> str:
|
||||
"""Resolve sqlite_dir to an absolute path (relative to CWD)."""
|
||||
from pathlib import Path
|
||||
|
||||
return str(Path(self.sqlite_dir).resolve())
|
||||
|
||||
@property
|
||||
def sqlite_path(self) -> str:
|
||||
"""Unified SQLite file path shared by checkpointer and app."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
|
||||
|
||||
# Backward-compatible aliases
|
||||
@property
|
||||
def checkpointer_sqlite_path(self) -> str:
|
||||
"""SQLite file path for the LangGraph checkpointer (alias for sqlite_path)."""
|
||||
return self.sqlite_path
|
||||
|
||||
@property
|
||||
def app_sqlite_path(self) -> str:
|
||||
"""SQLite file path for application ORM data (alias for sqlite_path)."""
|
||||
return self.sqlite_path
|
||||
|
||||
@property
|
||||
def app_sqlalchemy_url(self) -> str:
|
||||
"""SQLAlchemy async URL for the application ORM engine."""
|
||||
if self.backend == "sqlite":
|
||||
return f"sqlite+aiosqlite:///{self.sqlite_path}"
|
||||
if self.backend == "postgres":
|
||||
url = self.postgres_url
|
||||
if url.startswith("postgresql://"):
|
||||
url = url.replace("postgresql://", "postgresql+asyncpg://", 1)
|
||||
return url
|
||||
raise ValueError(f"No SQLAlchemy URL for backend={self.backend!r}")
|
||||
@ -1,13 +0,0 @@
|
||||
"""DeerFlow application persistence layer (SQLAlchemy 2.0 async ORM).
|
||||
|
||||
This module manages DeerFlow's own application data -- runs metadata,
|
||||
thread ownership, cron jobs, users. It is completely separate from
|
||||
LangGraph's checkpointer, which manages graph execution state.
|
||||
|
||||
Usage:
|
||||
from deerflow.persistence import init_engine, close_engine, get_session_factory
|
||||
"""
|
||||
|
||||
from deerflow.persistence.engine import close_engine, get_engine, get_session_factory, init_engine
|
||||
|
||||
__all__ = ["close_engine", "get_engine", "get_session_factory", "init_engine"]
|
||||
@ -1,40 +0,0 @@
|
||||
"""SQLAlchemy declarative base with automatic to_dict support.
|
||||
|
||||
All DeerFlow ORM models inherit from this Base. It provides a generic
|
||||
to_dict() method via SQLAlchemy's inspect() so individual models don't
|
||||
need to write their own serialization logic.
|
||||
|
||||
LangGraph's checkpointer tables are NOT managed by this Base.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import inspect as sa_inspect
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Base class for all DeerFlow ORM models.
|
||||
|
||||
Provides:
|
||||
- Automatic to_dict() via SQLAlchemy column inspection.
|
||||
- Standard __repr__() showing all column values.
|
||||
"""
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None) -> dict:
|
||||
"""Convert ORM instance to plain dict.
|
||||
|
||||
Uses SQLAlchemy's inspect() to iterate mapped column attributes.
|
||||
|
||||
Args:
|
||||
exclude: Optional set of column keys to omit.
|
||||
|
||||
Returns:
|
||||
Dict of {column_key: value} for all mapped columns.
|
||||
"""
|
||||
exclude = exclude or set()
|
||||
return {c.key: getattr(self, c.key) for c in sa_inspect(type(self)).mapper.column_attrs if c.key not in exclude}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
cols = ", ".join(f"{c.key}={getattr(self, c.key)!r}" for c in sa_inspect(type(self)).mapper.column_attrs)
|
||||
return f"{type(self).__name__}({cols})"
|
||||
@ -1,190 +0,0 @@
|
||||
"""Async SQLAlchemy engine lifecycle management.
|
||||
|
||||
Initializes at Gateway startup, provides session factory for
|
||||
repositories, disposes at shutdown.
|
||||
|
||||
When database.backend="memory", init_engine is a no-op and
|
||||
get_session_factory() returns None. Repositories must check for
|
||||
None and fall back to in-memory implementations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
|
||||
def _json_serializer(obj: object) -> str:
|
||||
"""JSON serializer with ensure_ascii=False for Chinese character support."""
|
||||
return json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_engine: AsyncEngine | None = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
async def _auto_create_postgres_db(url: str) -> None:
|
||||
"""Connect to the ``postgres`` maintenance DB and CREATE DATABASE.
|
||||
|
||||
The target database name is extracted from *url*. The connection is
|
||||
made to the default ``postgres`` database on the same server using
|
||||
``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a
|
||||
transaction).
|
||||
"""
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
parsed = make_url(url)
|
||||
db_name = parsed.database
|
||||
if not db_name:
|
||||
raise ValueError("Cannot auto-create database: no database name in URL")
|
||||
|
||||
# Connect to the default 'postgres' database to issue CREATE DATABASE
|
||||
maint_url = parsed.set(database="postgres")
|
||||
maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT")
|
||||
try:
|
||||
async with maint_engine.connect() as conn:
|
||||
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
|
||||
logger.info("Auto-created PostgreSQL database: %s", db_name)
|
||||
finally:
|
||||
await maint_engine.dispose()
|
||||
|
||||
|
||||
async def init_engine(
|
||||
backend: str,
|
||||
*,
|
||||
url: str = "",
|
||||
echo: bool = False,
|
||||
pool_size: int = 5,
|
||||
sqlite_dir: str = "",
|
||||
) -> None:
|
||||
"""Create the async engine and session factory, then auto-create tables.
|
||||
|
||||
Args:
|
||||
backend: "memory", "sqlite", or "postgres".
|
||||
url: SQLAlchemy async URL (for sqlite/postgres).
|
||||
echo: Echo SQL to log.
|
||||
pool_size: Postgres connection pool size.
|
||||
sqlite_dir: Directory to create for SQLite (ensured to exist).
|
||||
"""
|
||||
global _engine, _session_factory
|
||||
|
||||
if backend == "memory":
|
||||
logger.info("Persistence backend=memory -- ORM engine not initialized")
|
||||
return
|
||||
|
||||
if backend == "postgres":
|
||||
try:
|
||||
import asyncpg # noqa: F401
|
||||
except ImportError:
|
||||
raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None
|
||||
|
||||
if backend == "sqlite":
|
||||
import os
|
||||
|
||||
from sqlalchemy import event
|
||||
|
||||
os.makedirs(sqlite_dir or ".", exist_ok=True)
|
||||
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
|
||||
|
||||
# Enable WAL on every new connection. SQLite PRAGMA settings are
|
||||
# per-connection, so we wire the listener instead of running PRAGMA
|
||||
# once at startup. WAL gives concurrent reads + writers without
|
||||
# blocking and is the standard recommendation for any production
|
||||
# SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion
|
||||
# ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only
|
||||
# at WAL checkpoint boundaries instead of every commit.
|
||||
# Note: we do not set PRAGMA busy_timeout here — Python's sqlite3
|
||||
# driver already defaults to a 5-second busy timeout (see the
|
||||
# ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite /
|
||||
# SQLAlchemy's aiosqlite dialect inherit that default. Setting
|
||||
# it again would be a no-op.
|
||||
@event.listens_for(_engine.sync_engine, "connect")
|
||||
def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract
|
||||
cursor = dbapi_conn.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL;")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL;")
|
||||
cursor.execute("PRAGMA foreign_keys=ON;")
|
||||
finally:
|
||||
cursor.close()
|
||||
elif backend == "postgres":
|
||||
_engine = create_async_engine(
|
||||
url,
|
||||
echo=echo,
|
||||
pool_size=pool_size,
|
||||
pool_pre_ping=True,
|
||||
json_serializer=_json_serializer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown persistence backend: {backend!r}")
|
||||
|
||||
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||
|
||||
# Auto-create tables (dev convenience). Production should use Alembic.
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
# Import all models so Base.metadata discovers them.
|
||||
# When no models exist yet (scaffolding phase), this is a no-op.
|
||||
try:
|
||||
import deerflow.persistence.models # noqa: F401
|
||||
except ImportError:
|
||||
# Models package not yet available — tables won't be auto-created.
|
||||
# This is expected during initial scaffolding or minimal installs.
|
||||
logger.debug("deerflow.persistence.models not found; skipping auto-create tables")
|
||||
|
||||
try:
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
except Exception as exc:
|
||||
if backend == "postgres" and "does not exist" in str(exc):
|
||||
# Database not yet created — attempt to auto-create it, then retry.
|
||||
await _auto_create_postgres_db(url)
|
||||
# Rebuild engine against the now-existing database
|
||||
await _engine.dispose()
|
||||
_engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer)
|
||||
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info("Persistence engine initialized: backend=%s", backend)
|
||||
|
||||
|
||||
async def init_engine_from_config(config) -> None:
|
||||
"""Convenience: init engine from a DatabaseConfig object."""
|
||||
if config.backend == "memory":
|
||||
await init_engine("memory")
|
||||
return
|
||||
await init_engine(
|
||||
backend=config.backend,
|
||||
url=config.app_sqlalchemy_url,
|
||||
echo=config.echo_sql,
|
||||
pool_size=config.pool_size,
|
||||
sqlite_dir=config.sqlite_dir if config.backend == "sqlite" else "",
|
||||
)
|
||||
|
||||
|
||||
def get_session_factory() -> async_sessionmaker[AsyncSession] | None:
|
||||
"""Return the async session factory, or None if backend=memory."""
|
||||
return _session_factory
|
||||
|
||||
|
||||
def get_engine() -> AsyncEngine | None:
|
||||
"""Return the async engine, or None if not initialized."""
|
||||
return _engine
|
||||
|
||||
|
||||
async def close_engine() -> None:
|
||||
"""Dispose the engine, release all connections."""
|
||||
global _engine, _session_factory
|
||||
if _engine is not None:
|
||||
await _engine.dispose()
|
||||
logger.info("Persistence engine closed")
|
||||
_engine = None
|
||||
_session_factory = None
|
||||
@ -1,6 +0,0 @@
|
||||
"""Feedback persistence — ORM and SQL repository."""
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.persistence.feedback.sql import FeedbackRepository
|
||||
|
||||
__all__ = ["FeedbackRepository", "FeedbackRow"]
|
||||
@ -1,34 +0,0 @@
|
||||
"""ORM model for user feedback on runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import DateTime, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class FeedbackRow(Base):
|
||||
__tablename__ = "feedback"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
|
||||
)
|
||||
|
||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||
# message_id is an optional RunEventStore event identifier —
|
||||
# allows feedback to target a specific message or the entire run
|
||||
|
||||
rating: Mapped[int] = mapped_column(nullable=False)
|
||||
# +1 (thumbs-up) or -1 (thumbs-down)
|
||||
|
||||
comment: Mapped[str | None] = mapped_column(Text)
|
||||
# Optional text feedback from the user
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
@ -1,217 +0,0 @@
|
||||
"""SQLAlchemy-backed feedback storage.
|
||||
|
||||
Each method acquires its own short-lived session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import case, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: FeedbackRow) -> dict:
|
||||
d = row.to_dict()
|
||||
val = d.get("created_at")
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
return d
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a feedback record. rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
user_id=resolved_user_id,
|
||||
message_id=message_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 100,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
feedback_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return False
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update feedback for (thread_id, run_id, user_id). rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert")
|
||||
async with self._sf() as session:
|
||||
stmt = select(FeedbackRow).where(
|
||||
FeedbackRow.thread_id == thread_id,
|
||||
FeedbackRow.run_id == run_id,
|
||||
FeedbackRow.user_id == resolved_user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is not None:
|
||||
row.rating = rating
|
||||
row.comment = comment
|
||||
row.created_at = datetime.now(UTC)
|
||||
else:
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
user_id=resolved_user_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def delete_by_run(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> bool:
|
||||
"""Delete the current user's feedback for a run. Returns True if a record was deleted."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
stmt = select(FeedbackRow).where(
|
||||
FeedbackRow.thread_id == thread_id,
|
||||
FeedbackRow.run_id == run_id,
|
||||
FeedbackRow.user_id == resolved_user_id,
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def list_by_thread_grouped(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict[str, dict]:
|
||||
"""Return feedback grouped by run_id for a thread: {run_id: feedback_dict}."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped")
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return {row.run_id: self._row_to_dict(row) for row in result.scalars()}
|
||||
|
||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
||||
"""Aggregate feedback stats for a run using database-side counting."""
|
||||
stmt = select(
|
||||
func.count().label("total"),
|
||||
func.coalesce(func.sum(case((FeedbackRow.rating == 1, 1), else_=0)), 0).label("positive"),
|
||||
func.coalesce(func.sum(case((FeedbackRow.rating == -1, 1), else_=0)), 0).label("negative"),
|
||||
).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||
async with self._sf() as session:
|
||||
row = (await session.execute(stmt)).one()
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"total": row.total,
|
||||
"positive": row.positive,
|
||||
"negative": row.negative,
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
[alembic]
|
||||
script_location = %(here)s
|
||||
# Default URL for offline mode / autogenerate.
|
||||
# Runtime uses engine from DeerFlow config.
|
||||
sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
@ -1,65 +0,0 @@
|
||||
"""Alembic environment for DeerFlow application tables.
|
||||
|
||||
ONLY manages DeerFlow's tables (runs, threads_meta, cron_jobs, users).
|
||||
LangGraph's checkpointer tables are managed by LangGraph itself -- they
|
||||
have their own schema lifecycle and must not be touched by Alembic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
# Import all models so metadata is populated.
|
||||
try:
|
||||
import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
|
||||
except ImportError:
|
||||
# Models not available — migration will work with existing metadata only.
|
||||
logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables")
|
||||
|
||||
config = context.config
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
render_as_batch=True,
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def do_run_migrations(connection):
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True, # Required for SQLite ALTER TABLE support
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online() -> None:
|
||||
connectable = create_async_engine(config.get_main_option("sqlalchemy.url"))
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_run_migrations)
|
||||
await connectable.dispose()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
asyncio.run(run_migrations_online())
|
||||
@ -1,23 +0,0 @@
|
||||
"""ORM model registration entry point.
|
||||
|
||||
Importing this module ensures all ORM models are registered with
|
||||
``Base.metadata`` so Alembic autogenerate detects every table.
|
||||
|
||||
The actual ORM classes have moved to entity-specific subpackages:
|
||||
- ``deerflow.persistence.thread_meta``
|
||||
- ``deerflow.persistence.run``
|
||||
- ``deerflow.persistence.feedback``
|
||||
- ``deerflow.persistence.user``
|
||||
|
||||
``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because
|
||||
its storage implementation lives in ``deerflow.runtime.events.store.db`` and
|
||||
there is no matching entity directory.
|
||||
"""
|
||||
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
|
||||
@ -1,35 +0,0 @@
|
||||
"""ORM model for run events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class RunEventRow(Base):
|
||||
__tablename__ = "run_events"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
# Owner of the conversation this event belongs to. Nullable for data
|
||||
# created before auth was introduced; populated by auth middleware on
|
||||
# new writes and by the boot-time orphan migration on existing rows.
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
# "message" | "trace" | "lifecycle"
|
||||
content: Mapped[str] = mapped_column(Text, default="")
|
||||
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
seq: Mapped[int] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
|
||||
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
||||
Index("ix_events_run", "thread_id", "run_id", "seq"),
|
||||
)
|
||||
@ -1,6 +0,0 @@
|
||||
"""Run metadata persistence — ORM and SQL repository."""
|
||||
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.persistence.run.sql import RunRepository
|
||||
|
||||
__all__ = ["RunRepository", "RunRow"]
|
||||
@ -1,49 +0,0 @@
|
||||
"""ORM model for run metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, DateTime, Index, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class RunRow(Base):
|
||||
__tablename__ = "runs"
|
||||
|
||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||
|
||||
model_name: Mapped[str | None] = mapped_column(String(128))
|
||||
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
error: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
# Convenience fields (for listing pages without querying RunEventStore)
|
||||
message_count: Mapped[int] = mapped_column(default=0)
|
||||
first_human_message: Mapped[str | None] = mapped_column(Text)
|
||||
last_ai_message: Mapped[str | None] = mapped_column(Text)
|
||||
|
||||
# Token usage (accumulated in-memory by RunJournal, written on run completion)
|
||||
total_input_tokens: Mapped[int] = mapped_column(default=0)
|
||||
total_output_tokens: Mapped[int] = mapped_column(default=0)
|
||||
total_tokens: Mapped[int] = mapped_column(default=0)
|
||||
llm_call_count: Mapped[int] = mapped_column(default=0)
|
||||
lead_agent_tokens: Mapped[int] = mapped_column(default=0)
|
||||
subagent_tokens: Mapped[int] = mapped_column(default=0)
|
||||
middleware_tokens: Mapped[int] = mapped_column(default=0)
|
||||
|
||||
# Follow-up association
|
||||
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||
|
||||
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
||||
@ -1,255 +0,0 @@
|
||||
"""SQLAlchemy-backed RunStore implementation.
|
||||
|
||||
Each method acquires and releases its own short-lived session.
|
||||
Run status updates happen from background workers that may live
|
||||
minutes -- we don't hold connections across long execution.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.runtime.runs.store.base import RunStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class RunRepository(RunStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _safe_json(obj: Any) -> Any:
|
||||
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
|
||||
if obj is None:
|
||||
return None
|
||||
if isinstance(obj, (str, int, float, bool)):
|
||||
return obj
|
||||
if isinstance(obj, dict):
|
||||
return {k: RunRepository._safe_json(v) for k, v in obj.items()}
|
||||
if isinstance(obj, (list, tuple)):
|
||||
return [RunRepository._safe_json(v) for v in obj]
|
||||
if hasattr(obj, "model_dump"):
|
||||
try:
|
||||
return obj.model_dump()
|
||||
except Exception:
|
||||
pass
|
||||
if hasattr(obj, "dict"):
|
||||
try:
|
||||
return obj.dict()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
json.dumps(obj)
|
||||
return obj
|
||||
except (TypeError, ValueError):
|
||||
return str(obj)
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: RunRow) -> dict[str, Any]:
|
||||
d = row.to_dict()
|
||||
# Remap JSON columns to match RunStore interface
|
||||
d["metadata"] = d.pop("metadata_json", {})
|
||||
d["kwargs"] = d.pop("kwargs_json", {})
|
||||
# Convert datetime to ISO string for consistency with MemoryRunStore
|
||||
for key in ("created_at", "updated_at"):
|
||||
val = d.get(key)
|
||||
if isinstance(val, datetime):
|
||||
d[key] = val.isoformat()
|
||||
return d
|
||||
|
||||
async def put(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
kwargs=None,
|
||||
error=None,
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=resolved_user_id,
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
kwargs_json=self._safe_json(kwargs) or {},
|
||||
error=error,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
created_at=datetime.fromisoformat(created_at) if created_at else now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
async def get(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return None
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_thread(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
limit=100,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
|
||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunRow.user_id == resolved_user_id)
|
||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_status(self, run_id, status, *, error=None):
|
||||
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
run_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
|
||||
async def list_pending(self, *, before=None):
|
||||
if before is None:
|
||||
before_dt = datetime.now(UTC)
|
||||
elif isinstance(before, datetime):
|
||||
before_dt = before
|
||||
else:
|
||||
before_dt = datetime.fromisoformat(before)
|
||||
stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= before_dt).order_by(RunRow.created_at.asc())
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
total_input_tokens: int = 0,
|
||||
total_output_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
llm_call_count: int = 0,
|
||||
lead_agent_tokens: int = 0,
|
||||
subagent_tokens: int = 0,
|
||||
middleware_tokens: int = 0,
|
||||
message_count: int = 0,
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Update status + token usage + convenience fields on run completion."""
|
||||
values: dict[str, Any] = {
|
||||
"status": status,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"llm_call_count": llm_call_count,
|
||||
"lead_agent_tokens": lead_agent_tokens,
|
||||
"subagent_tokens": subagent_tokens,
|
||||
"middleware_tokens": middleware_tokens,
|
||||
"message_count": message_count,
|
||||
"updated_at": datetime.now(UTC),
|
||||
}
|
||||
if last_ai_message is not None:
|
||||
values["last_ai_message"] = last_ai_message[:2000]
|
||||
if first_human_message is not None:
|
||||
values["first_human_message"] = first_human_message[:2000]
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||
_completed = RunRow.status.in_(("success", "error"))
|
||||
_thread = RunRow.thread_id == thread_id
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
func.coalesce(RunRow.model_name, "unknown").label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
||||
func.coalesce(func.sum(RunRow.total_output_tokens), 0).label("total_output_tokens"),
|
||||
func.coalesce(func.sum(RunRow.lead_agent_tokens), 0).label("lead_agent"),
|
||||
func.coalesce(func.sum(RunRow.subagent_tokens), 0).label("subagent"),
|
||||
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(_thread, _completed)
|
||||
.group_by(func.coalesce(RunRow.model_name, "unknown"))
|
||||
)
|
||||
|
||||
async with self._sf() as session:
|
||||
rows = (await session.execute(stmt)).all()
|
||||
|
||||
total_tokens = total_input = total_output = total_runs = 0
|
||||
lead_agent = subagent = middleware = 0
|
||||
by_model: dict[str, dict] = {}
|
||||
for r in rows:
|
||||
by_model[r.model] = {"tokens": r.total_tokens, "runs": r.runs}
|
||||
total_tokens += r.total_tokens
|
||||
total_input += r.total_input_tokens
|
||||
total_output += r.total_output_tokens
|
||||
total_runs += r.runs
|
||||
lead_agent += r.lead_agent
|
||||
subagent += r.subagent
|
||||
middleware += r.middleware
|
||||
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"total_input_tokens": total_input,
|
||||
"total_output_tokens": total_output,
|
||||
"total_runs": total_runs,
|
||||
"by_model": by_model,
|
||||
"by_caller": {
|
||||
"lead_agent": lead_agent,
|
||||
"subagent": subagent,
|
||||
"middleware": middleware,
|
||||
},
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
"""Thread metadata persistence — ORM, abstract store, and concrete implementations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.thread_meta.sql import ThreadMetaRepository
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langgraph.store.base import BaseStore
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
__all__ = [
|
||||
"MemoryThreadMetaStore",
|
||||
"ThreadMetaRepository",
|
||||
"ThreadMetaRow",
|
||||
"ThreadMetaStore",
|
||||
"make_thread_store",
|
||||
]
|
||||
|
||||
|
||||
def make_thread_store(
|
||||
session_factory: async_sessionmaker[AsyncSession] | None,
|
||||
store: BaseStore | None = None,
|
||||
) -> ThreadMetaStore:
|
||||
"""Create the appropriate ThreadMetaStore based on available backends.
|
||||
|
||||
Returns a SQL-backed repository when a session factory is available,
|
||||
otherwise falls back to the in-memory LangGraph Store implementation.
|
||||
"""
|
||||
if session_factory is not None:
|
||||
return ThreadMetaRepository(session_factory)
|
||||
if store is None:
|
||||
raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)")
|
||||
return MemoryThreadMetaStore(store)
|
||||
@ -1,76 +0,0 @@
|
||||
"""Abstract interface for thread metadata storage.
|
||||
|
||||
Implementations:
|
||||
- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy)
|
||||
- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode)
|
||||
|
||||
All mutating and querying methods accept a ``user_id`` parameter with
|
||||
three-state semantics (see :mod:`deerflow.runtime.user_context`):
|
||||
|
||||
- ``AUTO`` (default): resolve from the request-scoped contextvar.
|
||||
- Explicit ``str``: use the provided value verbatim.
|
||||
- Explicit ``None``: bypass owner filtering (migration/CLI only).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel
|
||||
|
||||
|
||||
class ThreadMetaStore(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
metadata: dict | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
"""Merge ``metadata`` into the thread's metadata field.
|
||||
|
||||
Existing keys are overwritten by the new values; keys absent from
|
||||
``metadata`` are preserved. No-op if the thread does not exist
|
||||
or the owner check fails.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
pass
|
||||
@ -1,149 +0,0 @@
|
||||
"""In-memory ThreadMetaStore backed by LangGraph BaseStore.
|
||||
|
||||
Used when database.backend=memory. Delegates to the LangGraph Store's
|
||||
``("threads",)`` namespace — the same namespace used by the Gateway
|
||||
router for thread records.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
THREADS_NS: tuple[str, ...] = ("threads",)
|
||||
|
||||
|
||||
class MemoryThreadMetaStore(ThreadMetaStore):
|
||||
def __init__(self, store: BaseStore) -> None:
|
||||
self._store = store
|
||||
|
||||
async def _get_owned_record(
|
||||
self,
|
||||
thread_id: str,
|
||||
user_id: str | None | _AutoSentinel,
|
||||
method_name: str,
|
||||
) -> dict | None:
|
||||
"""Fetch a record and verify ownership. Returns a mutable copy, or None."""
|
||||
resolved = resolve_user_id(user_id, method_name=method_name)
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return None
|
||||
record = dict(item.value)
|
||||
if resolved is not None and record.get("user_id") != resolved:
|
||||
return None
|
||||
return record
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create")
|
||||
now = time.time()
|
||||
record: dict[str, Any] = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"user_id": resolved_user_id,
|
||||
"display_name": display_name,
|
||||
"status": "idle",
|
||||
"metadata": metadata or {},
|
||||
"values": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
return record
|
||||
|
||||
async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None:
|
||||
return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get")
|
||||
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
metadata: dict | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search")
|
||||
filter_dict: dict[str, Any] = {}
|
||||
if metadata:
|
||||
filter_dict.update(metadata)
|
||||
if status:
|
||||
filter_dict["status"] = status
|
||||
if resolved_user_id is not None:
|
||||
filter_dict["user_id"] = resolved_user_id
|
||||
|
||||
items = await self._store.asearch(
|
||||
THREADS_NS,
|
||||
filter=filter_dict or None,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return [self._item_to_dict(item) for item in items]
|
||||
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
item = await self._store.aget(THREADS_NS, thread_id)
|
||||
if item is None:
|
||||
return not require_existing
|
||||
record_user_id = item.value.get("user_id")
|
||||
if record_user_id is None:
|
||||
return True
|
||||
return record_user_id == user_id
|
||||
|
||||
async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name")
|
||||
if record is None:
|
||||
return
|
||||
record["display_name"] = display_name
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status")
|
||||
if record is None:
|
||||
return
|
||||
record["status"] = status
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata")
|
||||
if record is None:
|
||||
return
|
||||
merged = dict(record.get("metadata") or {})
|
||||
merged.update(metadata)
|
||||
record["metadata"] = merged
|
||||
record["updated_at"] = time.time()
|
||||
await self._store.aput(THREADS_NS, thread_id, record)
|
||||
|
||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||
if record is None:
|
||||
return
|
||||
await self._store.adelete(THREADS_NS, thread_id)
|
||||
|
||||
@staticmethod
|
||||
def _item_to_dict(item) -> dict[str, Any]:
|
||||
"""Convert a Store SearchItem to the dict format expected by callers."""
|
||||
val = item.value
|
||||
return {
|
||||
"thread_id": item.key,
|
||||
"assistant_id": val.get("assistant_id"),
|
||||
"user_id": val.get("user_id"),
|
||||
"display_name": val.get("display_name"),
|
||||
"status": val.get("status", "idle"),
|
||||
"metadata": val.get("metadata", {}),
|
||||
"created_at": str(val.get("created_at", "")),
|
||||
"updated_at": str(val.get("updated_at", "")),
|
||||
}
|
||||
@ -1,23 +0,0 @@
|
||||
"""ORM model for thread metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class ThreadMetaRow(Base):
|
||||
__tablename__ = "threads_meta"
|
||||
|
||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||
@ -1,217 +0,0 @@
|
||||
"""SQLAlchemy-backed thread metadata repository."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||
|
||||
|
||||
class ThreadMetaRepository(ThreadMetaStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]:
|
||||
d = row.to_dict()
|
||||
d["metadata"] = d.pop("metadata_json", {})
|
||||
for key in ("created_at", "updated_at"):
|
||||
val = d.get(key)
|
||||
if isinstance(val, datetime):
|
||||
d[key] = val.isoformat()
|
||||
return d
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
display_name: str | None = None,
|
||||
metadata: dict | None = None,
|
||||
) -> dict:
|
||||
# Auto-resolve user_id from contextvar when AUTO; explicit None
|
||||
# creates an orphan row (used by migration scripts).
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
|
||||
now = datetime.now(UTC)
|
||||
row = ThreadMetaRow(
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=resolved_user_id,
|
||||
display_name=display_name,
|
||||
metadata_json=metadata or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> dict | None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return None
|
||||
# Enforce owner filter unless explicitly bypassed (user_id=None).
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return None
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||
"""Check if ``user_id`` has access to ``thread_id``.
|
||||
|
||||
Two modes — one row, two distinct semantics depending on what
|
||||
the caller is about to do:
|
||||
|
||||
- ``require_existing=False`` (default, permissive):
|
||||
Returns True for: row missing (untracked legacy thread),
|
||||
``row.user_id`` is None (shared / pre-auth data),
|
||||
or ``row.user_id == user_id``. Use for **read-style**
|
||||
decorators where treating an untracked thread as accessible
|
||||
preserves backward-compat.
|
||||
|
||||
- ``require_existing=True`` (strict):
|
||||
Returns True **only** when the row exists AND
|
||||
(``row.user_id == user_id`` OR ``row.user_id is None``).
|
||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||
state-update) so a thread that has *already been deleted*
|
||||
cannot be re-targeted by any caller — closing the
|
||||
delete-idempotence cross-user gap where the row vanishing
|
||||
made every other user appear to "own" it.
|
||||
"""
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return not require_existing
|
||||
if row.user_id is None:
|
||||
return True
|
||||
return row.user_id == user_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
metadata: dict | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> list[dict]:
|
||||
"""Search threads with optional metadata and status filters.
|
||||
|
||||
Owner filter is enforced by default: caller must be in a user
|
||||
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
||||
"""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
|
||||
if metadata:
|
||||
# When metadata filter is active, fetch a larger window and filter
|
||||
# in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
|
||||
# SQLite json_extract) for server-side filtering.
|
||||
stmt = stmt.limit(limit * 5 + offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = [self._row_to_dict(r) for r in result.scalars()]
|
||||
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
|
||||
return rows[offset : offset + limit]
|
||||
else:
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||
if resolved_user_id is None:
|
||||
return True # explicit bypass
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
return row is not None and row.user_id == resolved_user_id
|
||||
|
||||
async def update_display_name(
|
||||
self,
|
||||
thread_id: str,
|
||||
display_name: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def update_status(
|
||||
self,
|
||||
thread_id: str,
|
||||
status: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
|
||||
async with self._sf() as session:
|
||||
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||
return
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def update_metadata(
|
||||
self,
|
||||
thread_id: str,
|
||||
metadata: dict,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
"""Merge ``metadata`` into ``metadata_json``.
|
||||
|
||||
Read-modify-write inside a single session/transaction so concurrent
|
||||
callers see consistent state. No-op if the row does not exist or
|
||||
the user_id check fails.
|
||||
"""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
merged = dict(row.metadata_json or {})
|
||||
merged.update(metadata)
|
||||
row.metadata_json = merged
|
||||
row.updated_at = datetime.now(UTC)
|
||||
await session.commit()
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
) -> None:
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
|
||||
async with self._sf() as session:
|
||||
row = await session.get(ThreadMetaRow, thread_id)
|
||||
if row is None:
|
||||
return
|
||||
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||
return
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
@ -1,12 +0,0 @@
|
||||
"""User storage subpackage.
|
||||
|
||||
Holds the ORM model for the ``users`` table. The concrete repository
|
||||
implementation (``SQLiteUserRepository``) lives in the app layer
|
||||
(``app.gateway.auth.repositories.sqlite``) because it converts
|
||||
between the ORM row and the auth module's pydantic ``User`` class.
|
||||
This keeps the harness package free of any dependency on app code.
|
||||
"""
|
||||
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
__all__ = ["UserRow"]
|
||||
@ -1,59 +0,0 @@
|
||||
"""ORM model for the users table.
|
||||
|
||||
Lives in the harness persistence package so it is picked up by
|
||||
``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``,
|
||||
``run_events``, and ``feedback``. Using the shared engine means:
|
||||
|
||||
- One SQLite/Postgres database, one connection pool
|
||||
- One schema initialisation codepath
|
||||
- Consistent async sessions across auth and persistence reads
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Index, String, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class UserRow(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
# UUIDs are stored as 36-char strings for cross-backend portability.
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
|
||||
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
# "admin" | "user" — kept as plain string to avoid ALTER TABLE pain
|
||||
# when new roles are introduced.
|
||||
system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user")
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
nullable=False,
|
||||
default=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# OAuth linkage (optional). A partial unique index enforces one
|
||||
# account per (provider, oauth_id) pair, leaving NULL/NULL rows
|
||||
# unconstrained so plain password accounts can coexist.
|
||||
oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
# Auth lifecycle flags
|
||||
needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
token_version: Mapped[int] = mapped_column(nullable=False, default=0)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"idx_users_oauth_identity",
|
||||
"oauth_provider",
|
||||
"oauth_id",
|
||||
unique=True,
|
||||
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
|
||||
),
|
||||
)
|
||||
@ -1,9 +0,0 @@
|
||||
from .async_provider import make_checkpointer
|
||||
from .provider import checkpointer_context, get_checkpointer, reset_checkpointer
|
||||
|
||||
__all__ = [
|
||||
"get_checkpointer",
|
||||
"reset_checkpointer",
|
||||
"checkpointer_context",
|
||||
"make_checkpointer",
|
||||
]
|
||||
@ -1,159 +0,0 @@
|
||||
"""Async checkpointer factory.
|
||||
|
||||
Provides an **async context manager** for long-running async servers that need
|
||||
proper resource cleanup.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer # InMemorySaver if not configured
|
||||
|
||||
For sync usage see :mod:`deerflow.runtime.checkpointer.provider`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.checkpointer.provider import (
|
||||
POSTGRES_CONN_REQUIRED,
|
||||
POSTGRES_INSTALL,
|
||||
SQLITE_INSTALL,
|
||||
)
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that constructs and tears down a checkpointer."""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str)
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that constructs a checkpointer from unified DatabaseConfig."""
|
||||
if db_config.backend == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if db_config.backend == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = db_config.checkpointer_sqlite_path
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
async with AsyncSqliteSaver.from_conn_string(conn_str) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
if db_config.backend == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not db_config.postgres_url:
|
||||
raise ValueError("database.postgres_url is required for the postgres backend")
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown database backend: {db_config.backend!r}")
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_checkpointer() -> AsyncIterator[Checkpointer]:
|
||||
"""Async context manager that yields a checkpointer for the caller's lifetime.
|
||||
Resources are opened on enter and closed on exit -- no global state::
|
||||
|
||||
async with make_checkpointer() as checkpointer:
|
||||
app.state.checkpointer = checkpointer
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
|
||||
Priority:
|
||||
1. Legacy ``checkpointer:`` config section (backward compatible)
|
||||
2. Unified ``database:`` config section
|
||||
3. Default InMemorySaver
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
|
||||
# Legacy: standalone checkpointer config takes precedence
|
||||
if config.checkpointer is not None:
|
||||
async with _async_checkpointer(config.checkpointer) as saver:
|
||||
yield saver
|
||||
return
|
||||
|
||||
# Unified database config
|
||||
db_config = getattr(config, "database", None)
|
||||
if db_config is not None and db_config.backend != "memory":
|
||||
async with _async_checkpointer_from_database(db_config) as saver:
|
||||
yield saver
|
||||
return
|
||||
|
||||
# Default: in-memory
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
@ -1,191 +0,0 @@
|
||||
"""Sync checkpointer factory.
|
||||
|
||||
Provides a **sync singleton** and a **sync context manager** for LangGraph
|
||||
graph compilation and CLI tools.
|
||||
|
||||
Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
cp = get_checkpointer()
|
||||
|
||||
# One-shot — fresh connection, closed on block exit
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig
|
||||
from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error message constants — imported by aio.provider too
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
|
||||
"""Context manager that creates and tears down a sync checkpointer.
|
||||
|
||||
Returns a configured ``Checkpointer`` instance. Resource cleanup for any
|
||||
underlying connections or pools is handled by higher-level helpers in
|
||||
this module (such as the singleton factory or context manager); this
|
||||
function does not return a separate cleanup callback.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.checkpoint.sqlite import SqliteSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
with SqliteSaver.from_conn_string(conn_str) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using SqliteSaver (%s)", conn_str)
|
||||
yield saver
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.checkpoint.postgres import PostgresSaver
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
with PostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
saver.setup()
|
||||
logger.info("Checkpointer: using PostgresSaver")
|
||||
yield saver
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown checkpointer type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_checkpointer: Checkpointer | None = None
|
||||
_checkpointer_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_checkpointer() -> Checkpointer:
|
||||
"""Return the global sync checkpointer singleton, creating it on first call.
|
||||
|
||||
Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
|
||||
Raises:
|
||||
ImportError: If the required package for the configured backend is not installed.
|
||||
ValueError: If ``connection_string`` is missing for a backend that requires it.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
|
||||
if _checkpointer is not None:
|
||||
return _checkpointer
|
||||
|
||||
# Ensure app config is loaded before checking checkpointer config
|
||||
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section
|
||||
# but hasn't been loaded yet
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
# Only load app config lazily when neither the app config nor an explicit
|
||||
# checkpointer config has been initialized yet. This keeps tests that
|
||||
# intentionally set the global checkpointer config isolated from any
|
||||
# ambient config.yaml on disk.
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
# In test environments without config.yaml, this is expected.
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
if config is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
|
||||
_checkpointer = InMemorySaver()
|
||||
return _checkpointer
|
||||
|
||||
_checkpointer_ctx = _sync_checkpointer_cm(config)
|
||||
_checkpointer = _checkpointer_ctx.__enter__()
|
||||
|
||||
return _checkpointer
|
||||
|
||||
|
||||
def reset_checkpointer() -> None:
|
||||
"""Reset the sync singleton, forcing recreation on the next call.
|
||||
|
||||
Closes any open backend connections and clears the cached instance.
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _checkpointer, _checkpointer_ctx
|
||||
if _checkpointer_ctx is not None:
|
||||
try:
|
||||
_checkpointer_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during checkpointer cleanup", exc_info=True)
|
||||
_checkpointer_ctx = None
|
||||
_checkpointer = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def checkpointer_context() -> Iterator[Checkpointer]:
|
||||
"""Sync context manager that yields a checkpointer and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_checkpointer`, this does **not** cache the instance —
|
||||
each ``with`` block creates and destroys its own connection. Use it in
|
||||
CLI scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with checkpointer_context() as cp:
|
||||
graph.invoke(input, config={"configurable": {"thread_id": "1"}})
|
||||
|
||||
Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
yield InMemorySaver()
|
||||
return
|
||||
|
||||
with _sync_checkpointer_cm(config.checkpointer) as saver:
|
||||
yield saver
|
||||
@ -1,4 +0,0 @@
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore"]
|
||||
@ -1,26 +0,0 @@
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
|
||||
|
||||
def make_run_event_store(config=None) -> RunEventStore:
|
||||
"""Create a RunEventStore based on run_events.backend configuration."""
|
||||
if config is None or config.backend == "memory":
|
||||
return MemoryRunEventStore()
|
||||
if config.backend == "db":
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
# database.backend=memory but run_events.backend=db -> fallback
|
||||
return MemoryRunEventStore()
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
return DbRunEventStore(sf, max_trace_content=config.max_trace_content)
|
||||
if config.backend == "jsonl":
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
return JsonlRunEventStore()
|
||||
raise ValueError(f"Unknown run_events backend: {config.backend!r}")
|
||||
|
||||
|
||||
__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"]
|
||||
@ -1,109 +0,0 @@
|
||||
"""Abstract interface for run event storage.
|
||||
|
||||
RunEventStore is the unified storage interface for run event streams.
|
||||
Messages (frontend display) and execution traces (debugging/audit) go
|
||||
through the same interface, distinguished by the ``category`` field.
|
||||
|
||||
Implementations:
|
||||
- MemoryRunEventStore: in-memory dict (development, tests)
|
||||
- Future: DB-backed store (SQLAlchemy ORM), JSONL file store
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
|
||||
|
||||
class RunEventStore(abc.ABC):
|
||||
"""Run event stream storage interface.
|
||||
|
||||
All implementations must guarantee:
|
||||
1. put() events are retrievable in subsequent queries
|
||||
2. seq is strictly increasing within the same thread
|
||||
3. list_messages() only returns category="message" events
|
||||
4. list_events() returns all events for the specified run
|
||||
5. Returned dicts match the RunEvent field structure
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def put(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
event_type: str,
|
||||
category: str,
|
||||
content: str | dict = "",
|
||||
metadata: dict | None = None,
|
||||
created_at: str | None = None,
|
||||
) -> dict:
|
||||
"""Write an event, auto-assign seq, return the complete record."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def put_batch(self, events: list[dict]) -> list[dict]:
|
||||
"""Batch-write events. Used by RunJournal flush buffer.
|
||||
|
||||
Each dict's keys match put()'s keyword arguments.
|
||||
Returns complete records with seq assigned.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages (category=message) for a thread, ordered by seq ascending.
|
||||
|
||||
Supports bidirectional cursor pagination:
|
||||
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
||||
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
||||
- neither: return the latest ``limit`` records (ascending)
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
) -> list[dict]:
|
||||
"""Return the full event stream for a run, ordered by seq ascending.
|
||||
|
||||
Optionally filter by event_types.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return displayable messages (category=message) for a specific run, ordered by seq ascending.
|
||||
|
||||
Supports bidirectional cursor pagination:
|
||||
- after_seq: return the first ``limit`` records with seq > after_seq (ascending)
|
||||
- before_seq: return the last ``limit`` records with seq < before_seq (ascending)
|
||||
- neither: return the latest ``limit`` records (ascending)
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
"""Count displayable messages (category=message) in a thread."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_by_thread(self, thread_id: str) -> int:
|
||||
"""Delete all events for a thread. Return the number of deleted events."""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||
"""Delete all events for a specific run. Return the number of deleted events."""
|
||||
@ -1,286 +0,0 @@
|
||||
"""SQLAlchemy-backed RunEventStore implementation.
|
||||
|
||||
Persists events to the ``run_events`` table. Trace content is truncated
|
||||
at ``max_trace_content`` bytes to avoid bloating the database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DbRunEventStore(RunEventStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
|
||||
self._sf = session_factory
|
||||
self._max_trace_content = max_trace_content
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: RunEventRow) -> dict:
|
||||
d = row.to_dict()
|
||||
d["metadata"] = d.pop("event_metadata", {})
|
||||
val = d.get("created_at")
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
d.pop("id", None)
|
||||
# Restore dict content that was JSON-serialized on write
|
||||
raw = d.get("content", "")
|
||||
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
|
||||
try:
|
||||
d["content"] = json.loads(raw)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
# Content looked like JSON (content_is_dict flag) but failed to parse;
|
||||
# keep the raw string as-is.
|
||||
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
|
||||
return d
|
||||
|
||||
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
|
||||
if category == "trace":
|
||||
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
|
||||
encoded = text.encode("utf-8")
|
||||
if len(encoded) > self._max_trace_content:
|
||||
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
|
||||
content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
|
||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||
return content, metadata or {}
|
||||
|
||||
@staticmethod
|
||||
def _user_id_from_context() -> str | None:
|
||||
"""Soft read of user_id from contextvar for write paths.
|
||||
|
||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||
which is the expected case for background worker writes. HTTP
|
||||
request writes will have the contextvar set by auth middleware
|
||||
and get their user_id stamped automatically.
|
||||
|
||||
Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is
|
||||
typed as ``UUID`` by the auth layer, but ``run_events.user_id``
|
||||
is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object
|
||||
to a VARCHAR column ("type 'UUID' is not supported") — the
|
||||
INSERT would silently roll back and the worker would hang.
|
||||
"""
|
||||
user = get_current_user()
|
||||
return str(user.id) if user is not None else None
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
|
||||
This opens a dedicated transaction with a FOR UPDATE lock to
|
||||
assign a monotonic *seq*. For high-throughput writes use
|
||||
:meth:`put_batch`, which acquires the lock once for the whole
|
||||
batch. Currently the only caller is ``worker.run_agent`` for
|
||||
the initial ``human_message`` event (once per run).
|
||||
"""
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
user_id=user_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=db_content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
seq += 1
|
||||
content = e.get("content", "")
|
||||
category = e.get("category", "trace")
|
||||
metadata = e.get("metadata")
|
||||
content, metadata = self._truncate_trace(category, content, metadata)
|
||||
if isinstance(content, dict):
|
||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||
else:
|
||||
db_content = content
|
||||
row = RunEventRow(
|
||||
thread_id=e["thread_id"],
|
||||
run_id=e["run_id"],
|
||||
user_id=e.get("user_id", user_id),
|
||||
event_type=e["event_type"],
|
||||
category=category,
|
||||
content=db_content,
|
||||
event_metadata=metadata,
|
||||
seq=seq,
|
||||
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
|
||||
)
|
||||
session.add(row)
|
||||
rows.append(row)
|
||||
return [self._row_to_dict(r) for r in rows]
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
# Forward pagination: first `limit` records after cursor
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
# before_seq or default (latest): take last `limit` records, return ascending
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
event_types=None,
|
||||
limit=500,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if event_types:
|
||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
limit=50,
|
||||
before_seq=None,
|
||||
after_seq=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||
stmt = select(RunEventRow).where(
|
||||
RunEventRow.thread_id == thread_id,
|
||||
RunEventRow.run_id == run_id,
|
||||
RunEventRow.category == "message",
|
||||
)
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventRow.seq > after_seq)
|
||||
|
||||
if after_seq is not None:
|
||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
else:
|
||||
stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = list(result.scalars())
|
||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||
|
||||
async def count_messages(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||
if resolved_user_id is not None:
|
||||
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def delete_by_thread(
|
||||
self,
|
||||
thread_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
|
||||
async def delete_by_run(
|
||||
self,
|
||||
thread_id,
|
||||
run_id,
|
||||
*,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
):
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
||||
async with self._sf() as session:
|
||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||
if resolved_user_id is not None:
|
||||
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||
count = await session.scalar(count_stmt) or 0
|
||||
if count > 0:
|
||||
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||
await session.commit()
|
||||
return count
|
||||
@ -1,187 +0,0 @@
|
||||
"""JSONL file-backed RunEventStore implementation.
|
||||
|
||||
Each run's events are stored in a single file:
|
||||
``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl``
|
||||
|
||||
All categories (message, trace, lifecycle) are in the same file.
|
||||
This backend is suitable for lightweight single-node deployments.
|
||||
|
||||
Known trade-off: ``list_messages()`` must scan all run files for a
|
||||
thread since messages from multiple runs need unified seq ordering.
|
||||
``list_events()`` reads only one file -- the fast path.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
|
||||
|
||||
class JsonlRunEventStore(RunEventStore):
|
||||
def __init__(self, base_dir: str | Path | None = None):
|
||||
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
|
||||
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
|
||||
|
||||
@staticmethod
|
||||
def _validate_id(value: str, label: str) -> str:
|
||||
"""Validate that an ID is safe for use in filesystem paths."""
|
||||
if not value or not _SAFE_ID_PATTERN.match(value):
|
||||
raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}")
|
||||
return value
|
||||
|
||||
def _thread_dir(self, thread_id: str) -> Path:
|
||||
self._validate_id(thread_id, "thread_id")
|
||||
return self._base_dir / "threads" / thread_id / "runs"
|
||||
|
||||
def _run_file(self, thread_id: str, run_id: str) -> Path:
|
||||
self._validate_id(run_id, "run_id")
|
||||
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
|
||||
|
||||
def _next_seq(self, thread_id: str) -> int:
|
||||
self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1
|
||||
return self._seq_counters[thread_id]
|
||||
|
||||
def _ensure_seq_loaded(self, thread_id: str) -> None:
|
||||
"""Load max seq from existing files if not yet cached."""
|
||||
if thread_id in self._seq_counters:
|
||||
return
|
||||
max_seq = 0
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
for line in f.read_text(encoding="utf-8").strip().splitlines():
|
||||
try:
|
||||
record = json.loads(line)
|
||||
max_seq = max(max_seq, record.get("seq", 0))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
continue
|
||||
self._seq_counters[thread_id] = max_seq
|
||||
|
||||
def _write_record(self, record: dict) -> None:
|
||||
path = self._run_file(record["thread_id"], record["run_id"])
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n")
|
||||
|
||||
def _read_thread_events(self, thread_id: str) -> list[dict]:
|
||||
"""Read all events for a thread, sorted by seq."""
|
||||
events = []
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if not thread_dir.exists():
|
||||
return events
|
||||
for f in sorted(thread_dir.glob("*.jsonl")):
|
||||
for line in f.read_text(encoding="utf-8").strip().splitlines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", f)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]:
|
||||
"""Read events for a specific run file."""
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if not path.exists():
|
||||
return []
|
||||
events = []
|
||||
for line in path.read_text(encoding="utf-8").strip().splitlines():
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
events.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
logger.debug("Skipping malformed JSONL line in %s", path)
|
||||
continue
|
||||
events.sort(key=lambda e: e.get("seq", 0))
|
||||
return events
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
|
||||
self._ensure_seq_loaded(thread_id)
|
||||
seq = self._next_seq(thread_id)
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"seq": seq,
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self._write_record(record)
|
||||
return record
|
||||
|
||||
async def put_batch(self, events):
|
||||
if not events:
|
||||
return []
|
||||
results = []
|
||||
for ev in events:
|
||||
record = await self.put(**ev)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
messages = [e for e in all_events if e.get("category") == "message"]
|
||||
|
||||
if before_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] < before_seq]
|
||||
return messages[-limit:]
|
||||
elif after_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] > after_seq]
|
||||
return messages[:limit]
|
||||
else:
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
if event_types is not None:
|
||||
events = [e for e in events if e.get("event_type") in event_types]
|
||||
return events[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
filtered = [e for e in events if e.get("category") == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e.get("seq", 0) > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
return sum(1 for e in all_events if e.get("category") == "message")
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
all_events = self._read_thread_events(thread_id)
|
||||
count = len(all_events)
|
||||
thread_dir = self._thread_dir(thread_id)
|
||||
if thread_dir.exists():
|
||||
for f in thread_dir.glob("*.jsonl"):
|
||||
f.unlink()
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
return count
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
events = self._read_run_events(thread_id, run_id)
|
||||
count = len(events)
|
||||
path = self._run_file(thread_id, run_id)
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
return count
|
||||
@ -1,128 +0,0 @@
|
||||
"""In-memory RunEventStore. Used when run_events.backend=memory (default) and in tests.
|
||||
|
||||
Thread-safe for single-process async usage (no threading locks needed
|
||||
since all mutations happen within the same event loop).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
|
||||
class MemoryRunEventStore(RunEventStore):
|
||||
def __init__(self) -> None:
|
||||
self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list
|
||||
self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq
|
||||
|
||||
def _next_seq(self, thread_id: str) -> int:
|
||||
current = self._seq_counters.get(thread_id, 0)
|
||||
next_val = current + 1
|
||||
self._seq_counters[thread_id] = next_val
|
||||
return next_val
|
||||
|
||||
def _put_one(
|
||||
self,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
event_type: str,
|
||||
category: str,
|
||||
content: str | dict = "",
|
||||
metadata: dict | None = None,
|
||||
created_at: str | None = None,
|
||||
) -> dict:
|
||||
seq = self._next_seq(thread_id)
|
||||
record = {
|
||||
"thread_id": thread_id,
|
||||
"run_id": run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"seq": seq,
|
||||
"created_at": created_at or datetime.now(UTC).isoformat(),
|
||||
}
|
||||
self._events.setdefault(thread_id, []).append(record)
|
||||
return record
|
||||
|
||||
async def put(
|
||||
self,
|
||||
*,
|
||||
thread_id,
|
||||
run_id,
|
||||
event_type,
|
||||
category,
|
||||
content="",
|
||||
metadata=None,
|
||||
created_at=None,
|
||||
):
|
||||
return self._put_one(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type=event_type,
|
||||
category=category,
|
||||
content=content,
|
||||
metadata=metadata,
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
async def put_batch(self, events):
|
||||
results = []
|
||||
for ev in events:
|
||||
record = self._put_one(**ev)
|
||||
results.append(record)
|
||||
return results
|
||||
|
||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
messages = [e for e in all_events if e["category"] == "message"]
|
||||
|
||||
if before_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] < before_seq]
|
||||
# Take the last `limit` records
|
||||
return messages[-limit:]
|
||||
elif after_seq is not None:
|
||||
messages = [e for e in messages if e["seq"] > after_seq]
|
||||
return messages[:limit]
|
||||
else:
|
||||
# Return the latest `limit` records, ascending
|
||||
return messages[-limit:]
|
||||
|
||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
filtered = [e for e in all_events if e["run_id"] == run_id]
|
||||
if event_types is not None:
|
||||
filtered = [e for e in filtered if e["event_type"] in event_types]
|
||||
return filtered[:limit]
|
||||
|
||||
async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"]
|
||||
if before_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] < before_seq]
|
||||
if after_seq is not None:
|
||||
filtered = [e for e in filtered if e["seq"] > after_seq]
|
||||
if after_seq is not None:
|
||||
return filtered[:limit]
|
||||
else:
|
||||
return filtered[-limit:] if len(filtered) > limit else filtered
|
||||
|
||||
async def count_messages(self, thread_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
return sum(1 for e in all_events if e["category"] == "message")
|
||||
|
||||
async def delete_by_thread(self, thread_id):
|
||||
events = self._events.pop(thread_id, [])
|
||||
self._seq_counters.pop(thread_id, None)
|
||||
return len(events)
|
||||
|
||||
async def delete_by_run(self, thread_id, run_id):
|
||||
all_events = self._events.get(thread_id, [])
|
||||
if not all_events:
|
||||
return 0
|
||||
remaining = [e for e in all_events if e["run_id"] != run_id]
|
||||
removed = len(all_events) - len(remaining)
|
||||
self._events[thread_id] = remaining
|
||||
return removed
|
||||
@ -1,374 +0,0 @@
|
||||
"""Run event capture via LangChain callbacks.
|
||||
|
||||
RunJournal sits between LangChain's callback mechanism and the pluggable
|
||||
RunEventStore. It standardizes callback data into RunEvent records and
|
||||
handles token usage accumulation.
|
||||
|
||||
Key design decisions:
|
||||
- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end
|
||||
- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and
|
||||
extracts the first human message for run.input, because it is more reliable than
|
||||
on_chain_start (fires on every node) — messages here are fully structured.
|
||||
- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation.
|
||||
- on_llm_end emits llm_response in OpenAI Chat Completions format
|
||||
- Token usage accumulated in memory, written to RunRow on run completion
|
||||
- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from deerflow.runtime.events.store.base import RunEventStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RunJournal(BaseCallbackHandler):
|
||||
"""LangChain callback handler that captures events to RunEventStore."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
event_store: RunEventStore,
|
||||
*,
|
||||
track_token_usage: bool = True,
|
||||
flush_threshold: int = 20,
|
||||
):
|
||||
super().__init__()
|
||||
self.run_id = run_id
|
||||
self.thread_id = thread_id
|
||||
self._store = event_store
|
||||
self._track_tokens = track_token_usage
|
||||
self._flush_threshold = flush_threshold
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
self._total_output_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._llm_call_count = 0
|
||||
self._lead_agent_tokens = 0
|
||||
self._subagent_tokens = 0
|
||||
self._middleware_tokens = 0
|
||||
|
||||
# Convenience fields
|
||||
self._last_ai_msg: str | None = None
|
||||
self._first_human_msg: str | None = None
|
||||
self._msg_count = 0
|
||||
|
||||
# Latency tracking
|
||||
self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time
|
||||
|
||||
# LLM request/response tracking
|
||||
self._llm_call_index = 0
|
||||
self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
inputs: dict[str, Any],
|
||||
*,
|
||||
run_id: UUID,
|
||||
parent_run_id: UUID | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
caller = self._identify_caller(tags)
|
||||
if parent_run_id is None:
|
||||
# Root graph invocation — emit a single trace event for the run start.
|
||||
chain_name = (serialized or {}).get("name", "unknown")
|
||||
self._put(
|
||||
event_type="run.start",
|
||||
category="trace",
|
||||
content={"chain": chain_name},
|
||||
metadata={"caller": caller, **(metadata or {})},
|
||||
)
|
||||
|
||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
|
||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._put(
|
||||
event_type="run.error",
|
||||
category="error",
|
||||
content=str(error),
|
||||
metadata={"error_type": type(error).__name__},
|
||||
)
|
||||
self._flush_sync()
|
||||
|
||||
# -- LLM callbacks --
|
||||
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict,
|
||||
messages: list[list[BaseMessage]],
|
||||
*,
|
||||
run_id: UUID,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Capture structured prompt messages for llm_request event.
|
||||
|
||||
This is also the canonical place to extract the first human message:
|
||||
messages are fully structured here, it fires only on real LLM calls,
|
||||
and the content is never compressed by checkpoint trimming.
|
||||
"""
|
||||
rid = str(run_id)
|
||||
self._llm_start_times[rid] = time.monotonic()
|
||||
self._llm_call_index += 1
|
||||
# Mark this run_id as seen so on_llm_end knows not to increment again.
|
||||
self._cached_prompts[rid] = []
|
||||
|
||||
logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
|
||||
|
||||
# Capture the first human message sent to any LLM in this run.
|
||||
if not self._first_human_msg:
|
||||
for batch in messages.reversed():
|
||||
for m in batch.reversed():
|
||||
if isinstance(m, HumanMessage) and m.name != "summary":
|
||||
caller = self._identify_caller(tags)
|
||||
self.set_first_human_message(m.text)
|
||||
self._put(
|
||||
event_type="llm.human.input",
|
||||
category="message",
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
|
||||
def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
# Fallback: on_chat_model_start is preferred. This just tracks latency.
|
||||
self._llm_start_times[str(run_id)] = time.monotonic()
|
||||
|
||||
def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None:
|
||||
messages: list[AnyMessage] = []
|
||||
logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}")
|
||||
for generation in response.generations:
|
||||
for gen in generation:
|
||||
if hasattr(gen, "message"):
|
||||
messages.append(gen.message)
|
||||
else:
|
||||
logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}")
|
||||
|
||||
for message in messages:
|
||||
caller = self._identify_caller(tags)
|
||||
|
||||
# Latency
|
||||
rid = str(run_id)
|
||||
start = self._llm_start_times.pop(rid, None)
|
||||
latency_ms = int((time.monotonic() - start) * 1000) if start else None
|
||||
|
||||
# Token usage from message
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# Resolve call index
|
||||
call_index = self._llm_call_index
|
||||
if rid not in self._cached_prompts:
|
||||
# Fallback: on_chat_model_start was not called
|
||||
self._llm_call_index += 1
|
||||
call_index = self._llm_call_index
|
||||
|
||||
# Trace event: llm_response (OpenAI completion format)
|
||||
self._put(
|
||||
event_type="llm.ai.response",
|
||||
category="message",
|
||||
content=message.model_dump(),
|
||||
metadata={
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
|
||||
# Token accumulation
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
|
||||
def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs):
|
||||
"""Handle tool start event, cache tool call ID for later correlation"""
|
||||
tool_call_id = str(run_id)
|
||||
logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}")
|
||||
|
||||
def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs):
|
||||
"""Handle tool end event, append message and clear node data"""
|
||||
try:
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}")
|
||||
finally:
|
||||
logger.info(f"Tool end for node {run_id}")
|
||||
|
||||
# -- Internal methods --
|
||||
|
||||
def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
|
||||
self._buffer.append(
|
||||
{
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
if len(self._buffer) >= self._flush_threshold:
|
||||
self._flush_sync()
|
||||
|
||||
def _flush_sync(self) -> None:
|
||||
"""Best-effort flush of buffer to RunEventStore.
|
||||
|
||||
BaseCallbackHandler methods are synchronous. If an event loop is
|
||||
running we schedule an async ``put_batch``; otherwise the events
|
||||
stay in the buffer and are flushed later by the async ``flush()``
|
||||
call in the worker's ``finally`` block.
|
||||
"""
|
||||
if not self._buffer:
|
||||
return
|
||||
# Skip if a flush is already in flight — avoids concurrent writes
|
||||
# to the same SQLite file from multiple fire-and-forget tasks.
|
||||
if self._pending_flush_tasks:
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# No event loop — keep events in buffer for later async flush.
|
||||
return
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
task = loop.create_task(self._flush_async(batch))
|
||||
self._pending_flush_tasks.add(task)
|
||||
task.add_done_callback(self._on_flush_done)
|
||||
|
||||
async def _flush_async(self, batch: list[dict]) -> None:
|
||||
try:
|
||||
await self._store.put_batch(batch)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to flush %d events for run %s — returning to buffer",
|
||||
len(batch),
|
||||
self.run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
# Return failed events to buffer for retry on next flush
|
||||
self._buffer = batch + self._buffer
|
||||
|
||||
def _on_flush_done(self, task: asyncio.Task) -> None:
|
||||
self._pending_flush_tasks.discard(task)
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
logger.warning("Journal flush task failed: %s", exc)
|
||||
|
||||
def _identify_caller(self, tags: list[str] | None, **kwargs) -> str:
|
||||
_tags = tags or kwargs.get("tags", [])
|
||||
for tag in _tags:
|
||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||
return tag
|
||||
# Default to lead_agent: the main agent graph does not inject
|
||||
# callback tags, while subagents and middleware explicitly tag
|
||||
# themselves.
|
||||
return "lead_agent"
|
||||
|
||||
# -- Public methods (called by worker) --
|
||||
|
||||
def set_first_human_message(self, content: str) -> None:
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
|
||||
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
|
||||
"""Record a middleware state-change event.
|
||||
|
||||
Called by middleware implementations when they perform a meaningful
|
||||
state change (e.g., title generation, summarization, HITL approval).
|
||||
Pure-observation middleware should not call this.
|
||||
|
||||
Args:
|
||||
tag: Short identifier for the middleware (e.g., "title", "summarize",
|
||||
"guardrail"). Used to form event_type="middleware:{tag}".
|
||||
name: Full middleware class name.
|
||||
hook: Lifecycle hook that triggered the action (e.g., "after_model").
|
||||
action: Specific action performed (e.g., "generate_title").
|
||||
changes: Dict describing the state changes made.
|
||||
"""
|
||||
self._put(
|
||||
event_type=f"middleware:{tag}",
|
||||
category="middleware",
|
||||
content={"name": name, "hook": hook, "action": action, "changes": changes},
|
||||
)
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._pending_flush_tasks:
|
||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||
|
||||
while self._buffer:
|
||||
batch = self._buffer[: self._flush_threshold]
|
||||
del self._buffer[: self._flush_threshold]
|
||||
try:
|
||||
await self._store.put_batch(batch)
|
||||
except Exception:
|
||||
self._buffer = batch + self._buffer
|
||||
raise
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
return {
|
||||
"total_input_tokens": self._total_input_tokens,
|
||||
"total_output_tokens": self._total_output_tokens,
|
||||
"total_tokens": self._total_tokens,
|
||||
"llm_call_count": self._llm_call_count,
|
||||
"lead_agent_tokens": self._lead_agent_tokens,
|
||||
"subagent_tokens": self._subagent_tokens,
|
||||
"middleware_tokens": self._middleware_tokens,
|
||||
"message_count": self._msg_count,
|
||||
"last_ai_message": self._last_ai_msg,
|
||||
"first_human_message": self._first_human_msg,
|
||||
}
|
||||
@ -1,31 +0,0 @@
|
||||
"""Store provider for the DeerFlow runtime.
|
||||
|
||||
Re-exports the public API of both the async provider (for long-running
|
||||
servers) and the sync provider (for CLI tools and the embedded client).
|
||||
|
||||
Async usage (FastAPI lifespan)::
|
||||
|
||||
from deerflow.runtime.store import make_store
|
||||
|
||||
async with make_store() as store:
|
||||
app.state.store = store
|
||||
|
||||
Sync usage (CLI / DeerFlowClient)::
|
||||
|
||||
from deerflow.runtime.store import get_store, store_context
|
||||
|
||||
store = get_store() # singleton
|
||||
with store_context() as store: ... # one-shot
|
||||
"""
|
||||
|
||||
from .async_provider import make_store
|
||||
from .provider import get_store, reset_store, store_context
|
||||
|
||||
__all__ = [
|
||||
# async
|
||||
"make_store",
|
||||
# sync
|
||||
"get_store",
|
||||
"reset_store",
|
||||
"store_context",
|
||||
]
|
||||
@ -1,28 +0,0 @@
|
||||
"""Shared SQLite connection utilities for store and checkpointer providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pathlib
|
||||
|
||||
from deerflow.config.paths import resolve_path
|
||||
|
||||
|
||||
def resolve_sqlite_conn_str(raw: str) -> str:
|
||||
"""Return a SQLite connection string ready for use with store/checkpointer backends.
|
||||
|
||||
SQLite special strings (``":memory:"`` and ``file:`` URIs) are returned
|
||||
unchanged. Plain filesystem paths — relative or absolute — are resolved
|
||||
to an absolute string via :func:`resolve_path`.
|
||||
"""
|
||||
if raw == ":memory:" or raw.startswith("file:"):
|
||||
return raw
|
||||
return str(resolve_path(raw))
|
||||
|
||||
|
||||
def ensure_sqlite_parent_dir(conn_str: str) -> None:
|
||||
"""Create parent directory for a SQLite filesystem path.
|
||||
|
||||
No-op for in-memory databases (``":memory:"``) and ``file:`` URIs.
|
||||
"""
|
||||
if conn_str != ":memory:" and not conn_str.startswith("file:"):
|
||||
pathlib.Path(conn_str).parent.mkdir(parents=True, exist_ok=True)
|
||||
@ -1,113 +0,0 @@
|
||||
"""Async Store factory — backend mirrors the configured checkpointer.
|
||||
|
||||
The store and checkpointer share the same ``checkpointer`` section in
|
||||
*config.yaml* so they always use the same persistence backend:
|
||||
|
||||
- ``type: memory`` → :class:`langgraph.store.memory.InMemoryStore`
|
||||
- ``type: sqlite`` → :class:`langgraph.store.sqlite.aio.AsyncSqliteStore`
|
||||
- ``type: postgres`` → :class:`langgraph.store.postgres.aio.AsyncPostgresStore`
|
||||
|
||||
Usage (e.g. FastAPI lifespan)::
|
||||
|
||||
from deerflow.runtime.store import make_store
|
||||
|
||||
async with make_store() as store:
|
||||
app.state.store = store
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal backend factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def _async_store(config) -> AsyncIterator[BaseStore]:
|
||||
"""Async context manager that constructs and tears down a Store.
|
||||
|
||||
The ``config`` argument is a :class:`deerflow.config.checkpointer_config.CheckpointerConfig`
|
||||
instance — the same object used by the checkpointer factory.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.info("Store: using InMemoryStore (in-process, not persistent)")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.store.sqlite.aio import AsyncSqliteStore
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_STORE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
|
||||
async with AsyncSqliteStore.from_conn_string(conn_str) as store:
|
||||
await store.setup()
|
||||
logger.info("Store: using AsyncSqliteStore (%s)", conn_str)
|
||||
yield store
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.store.postgres.aio import AsyncPostgresStore # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_STORE_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
async with AsyncPostgresStore.from_conn_string(config.connection_string) as store:
|
||||
await store.setup()
|
||||
logger.info("Store: using AsyncPostgresStore")
|
||||
yield store
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown store backend type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public async context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def make_store() -> AsyncIterator[BaseStore]:
|
||||
"""Async context manager that yields a Store whose backend matches the
|
||||
configured checkpointer.
|
||||
|
||||
Reads from the same ``checkpointer`` section of *config.yaml* used by
|
||||
:func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so
|
||||
that both singletons always use the same persistence technology::
|
||||
|
||||
async with make_store() as store:
|
||||
app.state.store = store
|
||||
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
``checkpointer`` section is configured (emits a WARNING in that case).
|
||||
"""
|
||||
config = get_app_config()
|
||||
|
||||
if config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
async with _async_store(config.checkpointer) as store:
|
||||
yield store
|
||||
@ -1,188 +0,0 @@
|
||||
"""Sync Store factory.
|
||||
|
||||
Provides a **sync singleton** and a **sync context manager** for CLI tools
|
||||
and the embedded :class:`~deerflow.client.DeerFlowClient`.
|
||||
|
||||
The backend mirrors the configured checkpointer so that both always use the
|
||||
same persistence technology. Supported backends: memory, sqlite, postgres.
|
||||
|
||||
Usage::
|
||||
|
||||
from deerflow.runtime.store.provider import get_store, store_context
|
||||
|
||||
# Singleton — reused across calls, closed on process exit
|
||||
store = get_store()
|
||||
|
||||
# One-shot — fresh connection, closed on block exit
|
||||
with store_context() as store:
|
||||
store.put(("ns",), "key", {"value": 1})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langgraph.store.base import BaseStore
|
||||
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error message constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite"
|
||||
POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool"
|
||||
POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _sync_store_cm(config) -> Iterator[BaseStore]:
|
||||
"""Context manager that creates and tears down a sync Store.
|
||||
|
||||
The ``config`` argument is a
|
||||
:class:`~deerflow.config.checkpointer_config.CheckpointerConfig` instance —
|
||||
the same object used by the checkpointer factory.
|
||||
"""
|
||||
if config.type == "memory":
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.info("Store: using InMemoryStore (in-process, not persistent)")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
if config.type == "sqlite":
|
||||
try:
|
||||
from langgraph.store.sqlite import SqliteStore
|
||||
except ImportError as exc:
|
||||
raise ImportError(SQLITE_STORE_INSTALL) from exc
|
||||
|
||||
conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db")
|
||||
ensure_sqlite_parent_dir(conn_str)
|
||||
|
||||
with SqliteStore.from_conn_string(conn_str) as store:
|
||||
store.setup()
|
||||
logger.info("Store: using SqliteStore (%s)", conn_str)
|
||||
yield store
|
||||
return
|
||||
|
||||
if config.type == "postgres":
|
||||
try:
|
||||
from langgraph.store.postgres import PostgresStore # type: ignore[import]
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_STORE_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
with PostgresStore.from_conn_string(config.connection_string) as store:
|
||||
store.setup()
|
||||
logger.info("Store: using PostgresStore")
|
||||
yield store
|
||||
return
|
||||
|
||||
raise ValueError(f"Unknown store backend type: {config.type!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_store: BaseStore | None = None
|
||||
_store_ctx = None # open context manager keeping the connection alive
|
||||
|
||||
|
||||
def get_store() -> BaseStore:
|
||||
"""Return the global sync Store singleton, creating it on first call.
|
||||
|
||||
Returns an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
checkpointer is configured in *config.yaml* (emits a WARNING in that case).
|
||||
|
||||
Raises:
|
||||
ImportError: If the required package for the configured backend is not installed.
|
||||
ValueError: If ``connection_string`` is missing for a backend that requires it.
|
||||
"""
|
||||
global _store, _store_ctx
|
||||
|
||||
if _store is not None:
|
||||
return _store
|
||||
|
||||
# Lazily load app config, mirroring the checkpointer singleton pattern so
|
||||
# that tests that set the global checkpointer config explicitly remain isolated.
|
||||
from deerflow.config.app_config import _app_config
|
||||
from deerflow.config.checkpointer_config import get_checkpointer_config
|
||||
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None and _app_config is None:
|
||||
try:
|
||||
get_app_config()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
config = get_checkpointer_config()
|
||||
|
||||
if config is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
_store = InMemoryStore()
|
||||
return _store
|
||||
|
||||
_store_ctx = _sync_store_cm(config)
|
||||
_store = _store_ctx.__enter__()
|
||||
return _store
|
||||
|
||||
|
||||
def reset_store() -> None:
|
||||
"""Reset the sync singleton, forcing recreation on the next call.
|
||||
|
||||
Closes any open backend connections and clears the cached instance.
|
||||
Useful in tests or after a configuration change.
|
||||
"""
|
||||
global _store, _store_ctx
|
||||
if _store_ctx is not None:
|
||||
try:
|
||||
_store_ctx.__exit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error during store cleanup", exc_info=True)
|
||||
_store_ctx = None
|
||||
_store = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sync context manager
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def store_context() -> Iterator[BaseStore]:
|
||||
"""Sync context manager that yields a Store and cleans up on exit.
|
||||
|
||||
Unlike :func:`get_store`, this does **not** cache the instance — each
|
||||
``with`` block creates and destroys its own connection. Use it in CLI
|
||||
scripts or tests where you want deterministic cleanup::
|
||||
|
||||
with store_context() as store:
|
||||
store.put(("threads",), thread_id, {...})
|
||||
|
||||
Yields an :class:`~langgraph.store.memory.InMemoryStore` when no
|
||||
checkpointer is configured in *config.yaml*.
|
||||
"""
|
||||
config = get_app_config()
|
||||
if config.checkpointer is None:
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
|
||||
yield InMemoryStore()
|
||||
return
|
||||
|
||||
with _sync_store_cm(config.checkpointer) as store:
|
||||
yield store
|
||||
Loading…
x
Reference in New Issue
Block a user