diff --git a/backend/packages/harness/deerflow/config/checkpointer_config.py b/backend/packages/harness/deerflow/config/checkpointer_config.py deleted file mode 100644 index 6947cefb7..000000000 --- a/backend/packages/harness/deerflow/config/checkpointer_config.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Configuration for LangGraph checkpointer.""" - -from typing import Literal - -from pydantic import BaseModel, Field - -CheckpointerType = Literal["memory", "sqlite", "postgres"] - - -class CheckpointerConfig(BaseModel): - """Configuration for LangGraph state persistence checkpointer.""" - - type: CheckpointerType = Field( - description="Checkpointer backend type. " - "'memory' is in-process only (lost on restart). " - "'sqlite' persists to a local file (requires langgraph-checkpoint-sqlite). " - "'postgres' persists to PostgreSQL (requires langgraph-checkpoint-postgres)." - ) - connection_string: str | None = Field( - default=None, - description="Connection string for sqlite (file path) or postgres (DSN). " - "Required for sqlite and postgres types. " - "For sqlite, use a file path like '.deer-flow/checkpoints.db' or ':memory:' for in-memory. " - "For postgres, use a DSN like 'postgresql://user:pass@localhost:5432/db'.", - ) - - -# Global configuration instance — None means no checkpointer is configured. -_checkpointer_config: CheckpointerConfig | None = None - - -def get_checkpointer_config() -> CheckpointerConfig | None: - """Get the current checkpointer configuration, or None if not configured.""" - return _checkpointer_config - - -def set_checkpointer_config(config: CheckpointerConfig | None) -> None: - """Set the checkpointer configuration.""" - global _checkpointer_config - _checkpointer_config = config - - -def load_checkpointer_config_from_dict(config_dict: dict) -> None: - """Load checkpointer configuration from a dictionary.""" - global _checkpointer_config - _checkpointer_config = CheckpointerConfig(**config_dict) diff --git a/backend/packages/harness/deerflow/config/database_config.py b/backend/packages/harness/deerflow/config/database_config.py deleted file mode 100644 index 37cfd579d..000000000 --- a/backend/packages/harness/deerflow/config/database_config.py +++ /dev/null @@ -1,102 +0,0 @@ -"""Unified database backend configuration. - -Controls BOTH the LangGraph checkpointer and the DeerFlow application -persistence layer (runs, threads metadata, users, etc.). The user -configures one backend; the system handles physical separation details. - -SQLite mode: checkpointer and app share a single .db file -({sqlite_dir}/deerflow.db) with WAL journal mode enabled on every -connection. WAL allows concurrent readers and a single writer without -blocking, making a unified file safe for both workloads. Writers -that contend for the lock wait via the default 5-second sqlite3 -busy timeout rather than failing immediately. - -Postgres mode: both use the same database URL but maintain independent -connection pools with different lifecycles. - -Memory mode: checkpointer uses MemorySaver, app uses in-memory stores. -No database is initialized. - -Sensitive values (postgres_url) should use $VAR syntax in config.yaml -to reference environment variables from .env: - - database: - backend: postgres - postgres_url: $DATABASE_URL - -The $VAR resolution is handled by AppConfig.resolve_env_variables() -before this config is instantiated -- DatabaseConfig itself does not -need to do any environment variable processing. 
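For orientation, a minimal sketch of how the derived properties defined just below resolve for each backend; the import path is inferred from this file's location and the values in the comments are illustrative:

from deerflow.config.database_config import DatabaseConfig

sqlite_cfg = DatabaseConfig(backend="sqlite", sqlite_dir=".deer-flow/data")
sqlite_cfg.sqlite_path         # "<cwd>/.deer-flow/data/deerflow.db" -- shared by checkpointer and app
sqlite_cfg.app_sqlalchemy_url  # "sqlite+aiosqlite:///<cwd>/.deer-flow/data/deerflow.db"

pg_cfg = DatabaseConfig(backend="postgres", postgres_url="postgresql://user:pass@localhost:5432/deerflow")
pg_cfg.app_sqlalchemy_url      # "postgresql+asyncpg://user:pass@localhost:5432/deerflow"

DatabaseConfig(backend="memory")  # no database is initialized; the ORM engine stays None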
-""" - -from __future__ import annotations - -import os -from typing import Literal - -from pydantic import BaseModel, Field - - -class DatabaseConfig(BaseModel): - backend: Literal["memory", "sqlite", "postgres"] = Field( - default="memory", - description=("Storage backend for both checkpointer and application data. 'memory' for development (no persistence across restarts), 'sqlite' for single-node deployment, 'postgres' for production multi-node deployment."), - ) - sqlite_dir: str = Field( - default=".deer-flow/data", - description=("Directory for the SQLite database file. Both checkpointer and application data share {sqlite_dir}/deerflow.db."), - ) - postgres_url: str = Field( - default="", - description=( - "PostgreSQL connection URL, shared by checkpointer and app. " - "Use $DATABASE_URL in config.yaml to reference .env. " - "Example: postgresql://user:pass@host:5432/deerflow " - "(the +asyncpg driver suffix is added automatically where needed)." - ), - ) - echo_sql: bool = Field( - default=False, - description="Echo all SQL statements to log (debug only).", - ) - pool_size: int = Field( - default=5, - description="Connection pool size for the app ORM engine (postgres only).", - ) - - # -- Derived helpers (not user-configured) -- - - @property - def _resolved_sqlite_dir(self) -> str: - """Resolve sqlite_dir to an absolute path (relative to CWD).""" - from pathlib import Path - - return str(Path(self.sqlite_dir).resolve()) - - @property - def sqlite_path(self) -> str: - """Unified SQLite file path shared by checkpointer and app.""" - return os.path.join(self._resolved_sqlite_dir, "deerflow.db") - - # Backward-compatible aliases - @property - def checkpointer_sqlite_path(self) -> str: - """SQLite file path for the LangGraph checkpointer (alias for sqlite_path).""" - return self.sqlite_path - - @property - def app_sqlite_path(self) -> str: - """SQLite file path for application ORM data (alias for sqlite_path).""" - return self.sqlite_path - - @property - def app_sqlalchemy_url(self) -> str: - """SQLAlchemy async URL for the application ORM engine.""" - if self.backend == "sqlite": - return f"sqlite+aiosqlite:///{self.sqlite_path}" - if self.backend == "postgres": - url = self.postgres_url - if url.startswith("postgresql://"): - url = url.replace("postgresql://", "postgresql+asyncpg://", 1) - return url - raise ValueError(f"No SQLAlchemy URL for backend={self.backend!r}") diff --git a/backend/packages/harness/deerflow/persistence/__init__.py b/backend/packages/harness/deerflow/persistence/__init__.py deleted file mode 100644 index dfd64be95..000000000 --- a/backend/packages/harness/deerflow/persistence/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""DeerFlow application persistence layer (SQLAlchemy 2.0 async ORM). - -This module manages DeerFlow's own application data -- runs metadata, -thread ownership, cron jobs, users. It is completely separate from -LangGraph's checkpointer, which manages graph execution state. 
- -Usage: - from deerflow.persistence import init_engine, close_engine, get_session_factory -""" - -from deerflow.persistence.engine import close_engine, get_engine, get_session_factory, init_engine - -__all__ = ["close_engine", "get_engine", "get_session_factory", "init_engine"] diff --git a/backend/packages/harness/deerflow/persistence/base.py b/backend/packages/harness/deerflow/persistence/base.py deleted file mode 100644 index fd99d5f74..000000000 --- a/backend/packages/harness/deerflow/persistence/base.py +++ /dev/null @@ -1,40 +0,0 @@ -"""SQLAlchemy declarative base with automatic to_dict support. - -All DeerFlow ORM models inherit from this Base. It provides a generic -to_dict() method via SQLAlchemy's inspect() so individual models don't -need to write their own serialization logic. - -LangGraph's checkpointer tables are NOT managed by this Base. -""" - -from __future__ import annotations - -from sqlalchemy import inspect as sa_inspect -from sqlalchemy.orm import DeclarativeBase - - -class Base(DeclarativeBase): - """Base class for all DeerFlow ORM models. - - Provides: - - Automatic to_dict() via SQLAlchemy column inspection. - - Standard __repr__() showing all column values. - """ - - def to_dict(self, *, exclude: set[str] | None = None) -> dict: - """Convert ORM instance to plain dict. - - Uses SQLAlchemy's inspect() to iterate mapped column attributes. - - Args: - exclude: Optional set of column keys to omit. - - Returns: - Dict of {column_key: value} for all mapped columns. - """ - exclude = exclude or set() - return {c.key: getattr(self, c.key) for c in sa_inspect(type(self)).mapper.column_attrs if c.key not in exclude} - - def __repr__(self) -> str: - cols = ", ".join(f"{c.key}={getattr(self, c.key)!r}" for c in sa_inspect(type(self)).mapper.column_attrs) - return f"{type(self).__name__}({cols})" diff --git a/backend/packages/harness/deerflow/persistence/engine.py b/backend/packages/harness/deerflow/persistence/engine.py deleted file mode 100644 index 2777c2450..000000000 --- a/backend/packages/harness/deerflow/persistence/engine.py +++ /dev/null @@ -1,190 +0,0 @@ -"""Async SQLAlchemy engine lifecycle management. - -Initializes at Gateway startup, provides session factory for -repositories, disposes at shutdown. - -When database.backend="memory", init_engine is a no-op and -get_session_factory() returns None. Repositories must check for -None and fall back to in-memory implementations. -""" - -from __future__ import annotations - -import json -import logging - -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine - - -def _json_serializer(obj: object) -> str: - """JSON serializer with ensure_ascii=False for Chinese character support.""" - return json.dumps(obj, ensure_ascii=False) - - -logger = logging.getLogger(__name__) - -_engine: AsyncEngine | None = None -_session_factory: async_sessionmaker[AsyncSession] | None = None - - -async def _auto_create_postgres_db(url: str) -> None: - """Connect to the ``postgres`` maintenance DB and CREATE DATABASE. - - The target database name is extracted from *url*. The connection is - made to the default ``postgres`` database on the same server using - ``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a - transaction). 
- """ - from sqlalchemy import text - from sqlalchemy.engine.url import make_url - - parsed = make_url(url) - db_name = parsed.database - if not db_name: - raise ValueError("Cannot auto-create database: no database name in URL") - - # Connect to the default 'postgres' database to issue CREATE DATABASE - maint_url = parsed.set(database="postgres") - maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT") - try: - async with maint_engine.connect() as conn: - await conn.execute(text(f'CREATE DATABASE "{db_name}"')) - logger.info("Auto-created PostgreSQL database: %s", db_name) - finally: - await maint_engine.dispose() - - -async def init_engine( - backend: str, - *, - url: str = "", - echo: bool = False, - pool_size: int = 5, - sqlite_dir: str = "", -) -> None: - """Create the async engine and session factory, then auto-create tables. - - Args: - backend: "memory", "sqlite", or "postgres". - url: SQLAlchemy async URL (for sqlite/postgres). - echo: Echo SQL to log. - pool_size: Postgres connection pool size. - sqlite_dir: Directory to create for SQLite (ensured to exist). - """ - global _engine, _session_factory - - if backend == "memory": - logger.info("Persistence backend=memory -- ORM engine not initialized") - return - - if backend == "postgres": - try: - import asyncpg # noqa: F401 - except ImportError: - raise ImportError("database.backend is set to 'postgres' but asyncpg is not installed.\nInstall it with:\n uv sync --extra postgres\nOr switch to backend: sqlite in config.yaml for single-node deployment.") from None - - if backend == "sqlite": - import os - - from sqlalchemy import event - - os.makedirs(sqlite_dir or ".", exist_ok=True) - _engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer) - - # Enable WAL on every new connection. SQLite PRAGMA settings are - # per-connection, so we wire the listener instead of running PRAGMA - # once at startup. WAL gives concurrent reads + writers without - # blocking and is the standard recommendation for any production - # SQLite deployment (TC-UPG-06 in AUTH_TEST_PLAN.md). The companion - # ``synchronous=NORMAL`` is the safe-and-fast pairing — fsync only - # at WAL checkpoint boundaries instead of every commit. - # Note: we do not set PRAGMA busy_timeout here — Python's sqlite3 - # driver already defaults to a 5-second busy timeout (see the - # ``timeout`` kwarg of ``sqlite3.connect``), and aiosqlite / - # SQLAlchemy's aiosqlite dialect inherit that default. Setting - # it again would be a no-op. - @event.listens_for(_engine.sync_engine, "connect") - def _enable_sqlite_wal(dbapi_conn, _record): # noqa: ARG001 — SQLAlchemy contract - cursor = dbapi_conn.cursor() - try: - cursor.execute("PRAGMA journal_mode=WAL;") - cursor.execute("PRAGMA synchronous=NORMAL;") - cursor.execute("PRAGMA foreign_keys=ON;") - finally: - cursor.close() - elif backend == "postgres": - _engine = create_async_engine( - url, - echo=echo, - pool_size=pool_size, - pool_pre_ping=True, - json_serializer=_json_serializer, - ) - else: - raise ValueError(f"Unknown persistence backend: {backend!r}") - - _session_factory = async_sessionmaker(_engine, expire_on_commit=False) - - # Auto-create tables (dev convenience). Production should use Alembic. - from deerflow.persistence.base import Base - - # Import all models so Base.metadata discovers them. - # When no models exist yet (scaffolding phase), this is a no-op. 
- try: - import deerflow.persistence.models # noqa: F401 - except ImportError: - # Models package not yet available — tables won't be auto-created. - # This is expected during initial scaffolding or minimal installs. - logger.debug("deerflow.persistence.models not found; skipping auto-create tables") - - try: - async with _engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - except Exception as exc: - if backend == "postgres" and "does not exist" in str(exc): - # Database not yet created — attempt to auto-create it, then retry. - await _auto_create_postgres_db(url) - # Rebuild engine against the now-existing database - await _engine.dispose() - _engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer) - _session_factory = async_sessionmaker(_engine, expire_on_commit=False) - async with _engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - else: - raise - - logger.info("Persistence engine initialized: backend=%s", backend) - - -async def init_engine_from_config(config) -> None: - """Convenience: init engine from a DatabaseConfig object.""" - if config.backend == "memory": - await init_engine("memory") - return - await init_engine( - backend=config.backend, - url=config.app_sqlalchemy_url, - echo=config.echo_sql, - pool_size=config.pool_size, - sqlite_dir=config.sqlite_dir if config.backend == "sqlite" else "", - ) - - -def get_session_factory() -> async_sessionmaker[AsyncSession] | None: - """Return the async session factory, or None if backend=memory.""" - return _session_factory - - -def get_engine() -> AsyncEngine | None: - """Return the async engine, or None if not initialized.""" - return _engine - - -async def close_engine() -> None: - """Dispose the engine, release all connections.""" - global _engine, _session_factory - if _engine is not None: - await _engine.dispose() - logger.info("Persistence engine closed") - _engine = None - _session_factory = None diff --git a/backend/packages/harness/deerflow/persistence/feedback/__init__.py b/backend/packages/harness/deerflow/persistence/feedback/__init__.py deleted file mode 100644 index ee958b027..000000000 --- a/backend/packages/harness/deerflow/persistence/feedback/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Feedback persistence — ORM and SQL repository.""" - -from deerflow.persistence.feedback.model import FeedbackRow -from deerflow.persistence.feedback.sql import FeedbackRepository - -__all__ = ["FeedbackRepository", "FeedbackRow"] diff --git a/backend/packages/harness/deerflow/persistence/feedback/model.py b/backend/packages/harness/deerflow/persistence/feedback/model.py deleted file mode 100644 index f06bc84e7..000000000 --- a/backend/packages/harness/deerflow/persistence/feedback/model.py +++ /dev/null @@ -1,34 +0,0 @@ -"""ORM model for user feedback on runs.""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import DateTime, String, Text, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column - -from deerflow.persistence.base import Base - - -class FeedbackRow(Base): - __tablename__ = "feedback" - - __table_args__ = ( - UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"), - ) - - feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True) - run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) - thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) - user_id: Mapped[str | None] 
= mapped_column(String(64), index=True) - message_id: Mapped[str | None] = mapped_column(String(64)) - # message_id is an optional RunEventStore event identifier — - # allows feedback to target a specific message or the entire run - - rating: Mapped[int] = mapped_column(nullable=False) - # +1 (thumbs-up) or -1 (thumbs-down) - - comment: Mapped[str | None] = mapped_column(Text) - # Optional text feedback from the user - - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) diff --git a/backend/packages/harness/deerflow/persistence/feedback/sql.py b/backend/packages/harness/deerflow/persistence/feedback/sql.py deleted file mode 100644 index 1db74ce84..000000000 --- a/backend/packages/harness/deerflow/persistence/feedback/sql.py +++ /dev/null @@ -1,217 +0,0 @@ -"""SQLAlchemy-backed feedback storage. - -Each method acquires its own short-lived session. -""" - -from __future__ import annotations - -import uuid -from datetime import UTC, datetime - -from sqlalchemy import case, func, select -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -from deerflow.persistence.feedback.model import FeedbackRow -from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id - - -class FeedbackRepository: - def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: - self._sf = session_factory - - @staticmethod - def _row_to_dict(row: FeedbackRow) -> dict: - d = row.to_dict() - val = d.get("created_at") - if isinstance(val, datetime): - d["created_at"] = val.isoformat() - return d - - async def create( - self, - *, - run_id: str, - thread_id: str, - rating: int, - user_id: str | None | _AutoSentinel = AUTO, - message_id: str | None = None, - comment: str | None = None, - ) -> dict: - """Create a feedback record. 
rating must be +1 or -1.""" - if rating not in (1, -1): - raise ValueError(f"rating must be +1 or -1, got {rating}") - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create") - row = FeedbackRow( - feedback_id=str(uuid.uuid4()), - run_id=run_id, - thread_id=thread_id, - user_id=resolved_user_id, - message_id=message_id, - rating=rating, - comment=comment, - created_at=datetime.now(UTC), - ) - async with self._sf() as session: - session.add(row) - await session.commit() - await session.refresh(row) - return self._row_to_dict(row) - - async def get( - self, - feedback_id: str, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> dict | None: - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get") - async with self._sf() as session: - row = await session.get(FeedbackRow, feedback_id) - if row is None: - return None - if resolved_user_id is not None and row.user_id != resolved_user_id: - return None - return self._row_to_dict(row) - - async def list_by_run( - self, - thread_id: str, - run_id: str, - *, - limit: int = 100, - user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run") - stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id) - if resolved_user_id is not None: - stmt = stmt.where(FeedbackRow.user_id == resolved_user_id) - stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def list_by_thread( - self, - thread_id: str, - *, - limit: int = 100, - user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread") - stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id) - if resolved_user_id is not None: - stmt = stmt.where(FeedbackRow.user_id == resolved_user_id) - stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def delete( - self, - feedback_id: str, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> bool: - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete") - async with self._sf() as session: - row = await session.get(FeedbackRow, feedback_id) - if row is None: - return False - if resolved_user_id is not None and row.user_id != resolved_user_id: - return False - await session.delete(row) - await session.commit() - return True - - async def upsert( - self, - *, - run_id: str, - thread_id: str, - rating: int, - user_id: str | None | _AutoSentinel = AUTO, - comment: str | None = None, - ) -> dict: - """Create or update feedback for (thread_id, run_id, user_id). 
rating must be +1 or -1.""" - if rating not in (1, -1): - raise ValueError(f"rating must be +1 or -1, got {rating}") - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.upsert") - async with self._sf() as session: - stmt = select(FeedbackRow).where( - FeedbackRow.thread_id == thread_id, - FeedbackRow.run_id == run_id, - FeedbackRow.user_id == resolved_user_id, - ) - result = await session.execute(stmt) - row = result.scalar_one_or_none() - if row is not None: - row.rating = rating - row.comment = comment - row.created_at = datetime.now(UTC) - else: - row = FeedbackRow( - feedback_id=str(uuid.uuid4()), - run_id=run_id, - thread_id=thread_id, - user_id=resolved_user_id, - rating=rating, - comment=comment, - created_at=datetime.now(UTC), - ) - session.add(row) - await session.commit() - await session.refresh(row) - return self._row_to_dict(row) - - async def delete_by_run( - self, - *, - thread_id: str, - run_id: str, - user_id: str | None | _AutoSentinel = AUTO, - ) -> bool: - """Delete the current user's feedback for a run. Returns True if a record was deleted.""" - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run") - async with self._sf() as session: - stmt = select(FeedbackRow).where( - FeedbackRow.thread_id == thread_id, - FeedbackRow.run_id == run_id, - FeedbackRow.user_id == resolved_user_id, - ) - result = await session.execute(stmt) - row = result.scalar_one_or_none() - if row is None: - return False - await session.delete(row) - await session.commit() - return True - - async def list_by_thread_grouped( - self, - thread_id: str, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> dict[str, dict]: - """Return feedback grouped by run_id for a thread: {run_id: feedback_dict}.""" - resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped") - stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id) - if resolved_user_id is not None: - stmt = stmt.where(FeedbackRow.user_id == resolved_user_id) - async with self._sf() as session: - result = await session.execute(stmt) - return {row.run_id: self._row_to_dict(row) for row in result.scalars()} - - async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict: - """Aggregate feedback stats for a run using database-side counting.""" - stmt = select( - func.count().label("total"), - func.coalesce(func.sum(case((FeedbackRow.rating == 1, 1), else_=0)), 0).label("positive"), - func.coalesce(func.sum(case((FeedbackRow.rating == -1, 1), else_=0)), 0).label("negative"), - ).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id) - async with self._sf() as session: - row = (await session.execute(stmt)).one() - return { - "run_id": run_id, - "total": row.total, - "positive": row.positive, - "negative": row.negative, - } diff --git a/backend/packages/harness/deerflow/persistence/migrations/alembic.ini b/backend/packages/harness/deerflow/persistence/migrations/alembic.ini deleted file mode 100644 index 71b4b1dc0..000000000 --- a/backend/packages/harness/deerflow/persistence/migrations/alembic.ini +++ /dev/null @@ -1,38 +0,0 @@ -[alembic] -script_location = %(here)s -# Default URL for offline mode / autogenerate. -# Runtime uses engine from DeerFlow config. 
-sqlalchemy.url = sqlite+aiosqlite:///./data/deerflow.db - -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/backend/packages/harness/deerflow/persistence/migrations/env.py b/backend/packages/harness/deerflow/persistence/migrations/env.py deleted file mode 100644 index 04c186fa0..000000000 --- a/backend/packages/harness/deerflow/persistence/migrations/env.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Alembic environment for DeerFlow application tables. - -ONLY manages DeerFlow's tables (runs, threads_meta, cron_jobs, users). -LangGraph's checkpointer tables are managed by LangGraph itself -- they -have their own schema lifecycle and must not be touched by Alembic. -""" - -from __future__ import annotations - -import asyncio -import logging -from logging.config import fileConfig - -from alembic import context -from sqlalchemy.ext.asyncio import create_async_engine - -from deerflow.persistence.base import Base - -# Import all models so metadata is populated. -try: - import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata -except ImportError: - # Models not available — migration will work with existing metadata only. - logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables") - -config = context.config -if config.config_file_name is not None: - fileConfig(config.config_file_name) - -target_metadata = Base.metadata - - -def run_migrations_offline() -> None: - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - render_as_batch=True, - ) - with context.begin_transaction(): - context.run_migrations() - - -def do_run_migrations(connection): - context.configure( - connection=connection, - target_metadata=target_metadata, - render_as_batch=True, # Required for SQLite ALTER TABLE support - ) - with context.begin_transaction(): - context.run_migrations() - - -async def run_migrations_online() -> None: - connectable = create_async_engine(config.get_main_option("sqlalchemy.url")) - async with connectable.connect() as connection: - await connection.run_sync(do_run_migrations) - await connectable.dispose() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - asyncio.run(run_migrations_online()) diff --git a/backend/packages/harness/deerflow/persistence/migrations/versions/.gitkeep b/backend/packages/harness/deerflow/persistence/migrations/versions/.gitkeep deleted file mode 100644 index e69de29bb..000000000 diff --git a/backend/packages/harness/deerflow/persistence/models/__init__.py b/backend/packages/harness/deerflow/persistence/models/__init__.py deleted file mode 100644 index ab29a3536..000000000 --- a/backend/packages/harness/deerflow/persistence/models/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -"""ORM model registration entry point. - -Importing this module ensures all ORM models are registered with -``Base.metadata`` so Alembic autogenerate detects every table. 
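A tiny sketch of the registration effect described above; the table names come from the models later in this diff:

import deerflow.persistence.models  # noqa: F401  (importing registers every ORM class with Base.metadata)
from deerflow.persistence.base import Base

print(sorted(Base.metadata.tables))
# ['feedback', 'run_events', 'runs', 'threads_meta', 'users']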
- -The actual ORM classes have moved to entity-specific subpackages: -- ``deerflow.persistence.thread_meta`` -- ``deerflow.persistence.run`` -- ``deerflow.persistence.feedback`` -- ``deerflow.persistence.user`` - -``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because -its storage implementation lives in ``deerflow.runtime.events.store.db`` and -there is no matching entity directory. -""" - -from deerflow.persistence.feedback.model import FeedbackRow -from deerflow.persistence.models.run_event import RunEventRow -from deerflow.persistence.run.model import RunRow -from deerflow.persistence.thread_meta.model import ThreadMetaRow -from deerflow.persistence.user.model import UserRow - -__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"] diff --git a/backend/packages/harness/deerflow/persistence/models/run_event.py b/backend/packages/harness/deerflow/persistence/models/run_event.py deleted file mode 100644 index 4f22b4616..000000000 --- a/backend/packages/harness/deerflow/persistence/models/run_event.py +++ /dev/null @@ -1,35 +0,0 @@ -"""ORM model for run events.""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column - -from deerflow.persistence.base import Base - - -class RunEventRow(Base): - __tablename__ = "run_events" - - id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) - thread_id: Mapped[str] = mapped_column(String(64), nullable=False) - run_id: Mapped[str] = mapped_column(String(64), nullable=False) - # Owner of the conversation this event belongs to. Nullable for data - # created before auth was introduced; populated by auth middleware on - # new writes and by the boot-time orphan migration on existing rows. 
- user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True) - event_type: Mapped[str] = mapped_column(String(32), nullable=False) - category: Mapped[str] = mapped_column(String(16), nullable=False) - # "message" | "trace" | "lifecycle" - content: Mapped[str] = mapped_column(Text, default="") - event_metadata: Mapped[dict] = mapped_column(JSON, default=dict) - seq: Mapped[int] = mapped_column(nullable=False) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) - - __table_args__ = ( - UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"), - Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"), - Index("ix_events_run", "thread_id", "run_id", "seq"), - ) diff --git a/backend/packages/harness/deerflow/persistence/run/__init__.py b/backend/packages/harness/deerflow/persistence/run/__init__.py deleted file mode 100644 index 0aa01e7ea..000000000 --- a/backend/packages/harness/deerflow/persistence/run/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Run metadata persistence — ORM and SQL repository.""" - -from deerflow.persistence.run.model import RunRow -from deerflow.persistence.run.sql import RunRepository - -__all__ = ["RunRepository", "RunRow"] diff --git a/backend/packages/harness/deerflow/persistence/run/model.py b/backend/packages/harness/deerflow/persistence/run/model.py deleted file mode 100644 index d0dfe4085..000000000 --- a/backend/packages/harness/deerflow/persistence/run/model.py +++ /dev/null @@ -1,49 +0,0 @@ -"""ORM model for run metadata.""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import JSON, DateTime, Index, String, Text -from sqlalchemy.orm import Mapped, mapped_column - -from deerflow.persistence.base import Base - - -class RunRow(Base): - __tablename__ = "runs" - - run_id: Mapped[str] = mapped_column(String(64), primary_key=True) - thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) - assistant_id: Mapped[str | None] = mapped_column(String(128)) - user_id: Mapped[str | None] = mapped_column(String(64), index=True) - status: Mapped[str] = mapped_column(String(20), default="pending") - # "pending" | "running" | "success" | "error" | "timeout" | "interrupted" - - model_name: Mapped[str | None] = mapped_column(String(128)) - multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject") - metadata_json: Mapped[dict] = mapped_column(JSON, default=dict) - kwargs_json: Mapped[dict] = mapped_column(JSON, default=dict) - error: Mapped[str | None] = mapped_column(Text) - - # Convenience fields (for listing pages without querying RunEventStore) - message_count: Mapped[int] = mapped_column(default=0) - first_human_message: Mapped[str | None] = mapped_column(Text) - last_ai_message: Mapped[str | None] = mapped_column(Text) - - # Token usage (accumulated in-memory by RunJournal, written on run completion) - total_input_tokens: Mapped[int] = mapped_column(default=0) - total_output_tokens: Mapped[int] = mapped_column(default=0) - total_tokens: Mapped[int] = mapped_column(default=0) - llm_call_count: Mapped[int] = mapped_column(default=0) - lead_agent_tokens: Mapped[int] = mapped_column(default=0) - subagent_tokens: Mapped[int] = mapped_column(default=0) - middleware_tokens: Mapped[int] = mapped_column(default=0) - - # Follow-up association - follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64)) - - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: 
datetime.now(UTC)) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) - - __table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),) diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py deleted file mode 100644 index fcd1a3411..000000000 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ /dev/null @@ -1,255 +0,0 @@ -"""SQLAlchemy-backed RunStore implementation. - -Each method acquires and releases its own short-lived session. -Run status updates happen from background workers that may live -minutes -- we don't hold connections across long execution. -""" - -from __future__ import annotations - -import json -from datetime import UTC, datetime -from typing import Any - -from sqlalchemy import func, select, update -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -from deerflow.persistence.run.model import RunRow -from deerflow.runtime.runs.store.base import RunStore -from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id - - -class RunRepository(RunStore): - def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: - self._sf = session_factory - - @staticmethod - def _safe_json(obj: Any) -> Any: - """Ensure obj is JSON-serializable. Falls back to model_dump() or str().""" - if obj is None: - return None - if isinstance(obj, (str, int, float, bool)): - return obj - if isinstance(obj, dict): - return {k: RunRepository._safe_json(v) for k, v in obj.items()} - if isinstance(obj, (list, tuple)): - return [RunRepository._safe_json(v) for v in obj] - if hasattr(obj, "model_dump"): - try: - return obj.model_dump() - except Exception: - pass - if hasattr(obj, "dict"): - try: - return obj.dict() - except Exception: - pass - try: - json.dumps(obj) - return obj - except (TypeError, ValueError): - return str(obj) - - @staticmethod - def _row_to_dict(row: RunRow) -> dict[str, Any]: - d = row.to_dict() - # Remap JSON columns to match RunStore interface - d["metadata"] = d.pop("metadata_json", {}) - d["kwargs"] = d.pop("kwargs_json", {}) - # Convert datetime to ISO string for consistency with MemoryRunStore - for key in ("created_at", "updated_at"): - val = d.get(key) - if isinstance(val, datetime): - d[key] = val.isoformat() - return d - - async def put( - self, - run_id, - *, - thread_id, - assistant_id=None, - user_id: str | None | _AutoSentinel = AUTO, - status="pending", - multitask_strategy="reject", - metadata=None, - kwargs=None, - error=None, - created_at=None, - follow_up_to_run_id=None, - ): - resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put") - now = datetime.now(UTC) - row = RunRow( - run_id=run_id, - thread_id=thread_id, - assistant_id=assistant_id, - user_id=resolved_user_id, - status=status, - multitask_strategy=multitask_strategy, - metadata_json=self._safe_json(metadata) or {}, - kwargs_json=self._safe_json(kwargs) or {}, - error=error, - follow_up_to_run_id=follow_up_to_run_id, - created_at=datetime.fromisoformat(created_at) if created_at else now, - updated_at=now, - ) - async with self._sf() as session: - session.add(row) - await session.commit() - - async def get( - self, - run_id, - *, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get") - async with self._sf() as session: - row = await session.get(RunRow, run_id) - if row 
is None: - return None - if resolved_user_id is not None and row.user_id != resolved_user_id: - return None - return self._row_to_dict(row) - - async def list_by_thread( - self, - thread_id, - *, - user_id: str | None | _AutoSentinel = AUTO, - limit=100, - ): - resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread") - stmt = select(RunRow).where(RunRow.thread_id == thread_id) - if resolved_user_id is not None: - stmt = stmt.where(RunRow.user_id == resolved_user_id) - stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def update_status(self, run_id, status, *, error=None): - values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)} - if error is not None: - values["error"] = error - async with self._sf() as session: - await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) - await session.commit() - - async def delete( - self, - run_id, - *, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete") - async with self._sf() as session: - row = await session.get(RunRow, run_id) - if row is None: - return - if resolved_user_id is not None and row.user_id != resolved_user_id: - return - await session.delete(row) - await session.commit() - - async def list_pending(self, *, before=None): - if before is None: - before_dt = datetime.now(UTC) - elif isinstance(before, datetime): - before_dt = before - else: - before_dt = datetime.fromisoformat(before) - stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= before_dt).order_by(RunRow.created_at.asc()) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def update_run_completion( - self, - run_id: str, - *, - status: str, - total_input_tokens: int = 0, - total_output_tokens: int = 0, - total_tokens: int = 0, - llm_call_count: int = 0, - lead_agent_tokens: int = 0, - subagent_tokens: int = 0, - middleware_tokens: int = 0, - message_count: int = 0, - last_ai_message: str | None = None, - first_human_message: str | None = None, - error: str | None = None, - ) -> None: - """Update status + token usage + convenience fields on run completion.""" - values: dict[str, Any] = { - "status": status, - "total_input_tokens": total_input_tokens, - "total_output_tokens": total_output_tokens, - "total_tokens": total_tokens, - "llm_call_count": llm_call_count, - "lead_agent_tokens": lead_agent_tokens, - "subagent_tokens": subagent_tokens, - "middleware_tokens": middleware_tokens, - "message_count": message_count, - "updated_at": datetime.now(UTC), - } - if last_ai_message is not None: - values["last_ai_message"] = last_ai_message[:2000] - if first_human_message is not None: - values["first_human_message"] = first_human_message[:2000] - if error is not None: - values["error"] = error - async with self._sf() as session: - await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) - await session.commit() - - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: - """Aggregate token usage via a single SQL GROUP BY query.""" - _completed = RunRow.status.in_(("success", "error")) - _thread = RunRow.thread_id == thread_id - - stmt = ( - select( - func.coalesce(RunRow.model_name, "unknown").label("model"), - 
func.count().label("runs"), - func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"), - func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"), - func.coalesce(func.sum(RunRow.total_output_tokens), 0).label("total_output_tokens"), - func.coalesce(func.sum(RunRow.lead_agent_tokens), 0).label("lead_agent"), - func.coalesce(func.sum(RunRow.subagent_tokens), 0).label("subagent"), - func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"), - ) - .where(_thread, _completed) - .group_by(func.coalesce(RunRow.model_name, "unknown")) - ) - - async with self._sf() as session: - rows = (await session.execute(stmt)).all() - - total_tokens = total_input = total_output = total_runs = 0 - lead_agent = subagent = middleware = 0 - by_model: dict[str, dict] = {} - for r in rows: - by_model[r.model] = {"tokens": r.total_tokens, "runs": r.runs} - total_tokens += r.total_tokens - total_input += r.total_input_tokens - total_output += r.total_output_tokens - total_runs += r.runs - lead_agent += r.lead_agent - subagent += r.subagent - middleware += r.middleware - - return { - "total_tokens": total_tokens, - "total_input_tokens": total_input, - "total_output_tokens": total_output, - "total_runs": total_runs, - "by_model": by_model, - "by_caller": { - "lead_agent": lead_agent, - "subagent": subagent, - "middleware": middleware, - }, - } diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py b/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py deleted file mode 100644 index 080ce8093..000000000 --- a/backend/packages/harness/deerflow/persistence/thread_meta/__init__.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Thread metadata persistence — ORM, abstract store, and concrete implementations.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from deerflow.persistence.thread_meta.base import ThreadMetaStore -from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore -from deerflow.persistence.thread_meta.model import ThreadMetaRow -from deerflow.persistence.thread_meta.sql import ThreadMetaRepository - -if TYPE_CHECKING: - from langgraph.store.base import BaseStore - from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -__all__ = [ - "MemoryThreadMetaStore", - "ThreadMetaRepository", - "ThreadMetaRow", - "ThreadMetaStore", - "make_thread_store", -] - - -def make_thread_store( - session_factory: async_sessionmaker[AsyncSession] | None, - store: BaseStore | None = None, -) -> ThreadMetaStore: - """Create the appropriate ThreadMetaStore based on available backends. - - Returns a SQL-backed repository when a session factory is available, - otherwise falls back to the in-memory LangGraph Store implementation. - """ - if session_factory is not None: - return ThreadMetaRepository(session_factory) - if store is None: - raise ValueError("make_thread_store requires either a session_factory (SQL) or a store (memory)") - return MemoryThreadMetaStore(store) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py deleted file mode 100644 index c87c10a16..000000000 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Abstract interface for thread metadata storage. 
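A sketch of choosing a backend with the make_thread_store factory shown above; the LangGraph InMemoryStore import is an assumption about what the Gateway passes in memory mode:

from deerflow.persistence import get_session_factory
from deerflow.persistence.thread_meta import make_thread_store

sf = get_session_factory()
if sf is not None:
    thread_store = make_thread_store(sf)  # SQL-backed ThreadMetaRepository (sqlite / postgres)
else:
    from langgraph.store.memory import InMemoryStore  # assumed memory-mode store
    thread_store = make_thread_store(None, store=InMemoryStore())  # MemoryThreadMetaStore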
- -Implementations: -- ThreadMetaRepository: SQL-backed (sqlite / postgres via SQLAlchemy) -- MemoryThreadMetaStore: wraps LangGraph BaseStore (memory mode) - -All mutating and querying methods accept a ``user_id`` parameter with -three-state semantics (see :mod:`deerflow.runtime.user_context`): - -- ``AUTO`` (default): resolve from the request-scoped contextvar. -- Explicit ``str``: use the provided value verbatim. -- Explicit ``None``: bypass owner filtering (migration/CLI only). -""" - -from __future__ import annotations - -import abc - -from deerflow.runtime.user_context import AUTO, _AutoSentinel - - -class ThreadMetaStore(abc.ABC): - @abc.abstractmethod - async def create( - self, - thread_id: str, - *, - assistant_id: str | None = None, - user_id: str | None | _AutoSentinel = AUTO, - display_name: str | None = None, - metadata: dict | None = None, - ) -> dict: - pass - - @abc.abstractmethod - async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None: - pass - - @abc.abstractmethod - async def search( - self, - *, - metadata: dict | None = None, - status: str | None = None, - limit: int = 100, - offset: int = 0, - user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: - pass - - @abc.abstractmethod - async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - pass - - @abc.abstractmethod - async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - pass - - @abc.abstractmethod - async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - """Merge ``metadata`` into the thread's metadata field. - - Existing keys are overwritten by the new values; keys absent from - ``metadata`` are preserved. No-op if the thread does not exist - or the owner check fails. - """ - pass - - @abc.abstractmethod - async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: - """Check if ``user_id`` has access to ``thread_id``.""" - pass - - @abc.abstractmethod - async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - pass diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py deleted file mode 100644 index ccf59ad42..000000000 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ /dev/null @@ -1,149 +0,0 @@ -"""In-memory ThreadMetaStore backed by LangGraph BaseStore. - -Used when database.backend=memory. Delegates to the LangGraph Store's -``("threads",)`` namespace — the same namespace used by the Gateway -router for thread records. -""" - -from __future__ import annotations - -import time -from typing import Any - -from langgraph.store.base import BaseStore - -from deerflow.persistence.thread_meta.base import ThreadMetaStore -from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id - -THREADS_NS: tuple[str, ...] = ("threads",) - - -class MemoryThreadMetaStore(ThreadMetaStore): - def __init__(self, store: BaseStore) -> None: - self._store = store - - async def _get_owned_record( - self, - thread_id: str, - user_id: str | None | _AutoSentinel, - method_name: str, - ) -> dict | None: - """Fetch a record and verify ownership. 
Returns a mutable copy, or None.""" - resolved = resolve_user_id(user_id, method_name=method_name) - item = await self._store.aget(THREADS_NS, thread_id) - if item is None: - return None - record = dict(item.value) - if resolved is not None and record.get("user_id") != resolved: - return None - return record - - async def create( - self, - thread_id: str, - *, - assistant_id: str | None = None, - user_id: str | None | _AutoSentinel = AUTO, - display_name: str | None = None, - metadata: dict | None = None, - ) -> dict: - resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.create") - now = time.time() - record: dict[str, Any] = { - "thread_id": thread_id, - "assistant_id": assistant_id, - "user_id": resolved_user_id, - "display_name": display_name, - "status": "idle", - "metadata": metadata or {}, - "values": {}, - "created_at": now, - "updated_at": now, - } - await self._store.aput(THREADS_NS, thread_id, record) - return record - - async def get(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> dict | None: - return await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.get") - - async def search( - self, - *, - metadata: dict | None = None, - status: str | None = None, - limit: int = 100, - offset: int = 0, - user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: - resolved_user_id = resolve_user_id(user_id, method_name="MemoryThreadMetaStore.search") - filter_dict: dict[str, Any] = {} - if metadata: - filter_dict.update(metadata) - if status: - filter_dict["status"] = status - if resolved_user_id is not None: - filter_dict["user_id"] = resolved_user_id - - items = await self._store.asearch( - THREADS_NS, - filter=filter_dict or None, - limit=limit, - offset=offset, - ) - return [self._item_to_dict(item) for item in items] - - async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: - item = await self._store.aget(THREADS_NS, thread_id) - if item is None: - return not require_existing - record_user_id = item.value.get("user_id") - if record_user_id is None: - return True - return record_user_id == user_id - - async def update_display_name(self, thread_id: str, display_name: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_display_name") - if record is None: - return - record["display_name"] = display_name - record["updated_at"] = time.time() - await self._store.aput(THREADS_NS, thread_id, record) - - async def update_status(self, thread_id: str, status: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_status") - if record is None: - return - record["status"] = status - record["updated_at"] = time.time() - await self._store.aput(THREADS_NS, thread_id, record) - - async def update_metadata(self, thread_id: str, metadata: dict, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_metadata") - if record is None: - return - merged = dict(record.get("metadata") or {}) - merged.update(metadata) - record["metadata"] = merged - record["updated_at"] = time.time() - await self._store.aput(THREADS_NS, thread_id, record) - - async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: - record = await self._get_owned_record(thread_id, user_id, 
"MemoryThreadMetaStore.delete") - if record is None: - return - await self._store.adelete(THREADS_NS, thread_id) - - @staticmethod - def _item_to_dict(item) -> dict[str, Any]: - """Convert a Store SearchItem to the dict format expected by callers.""" - val = item.value - return { - "thread_id": item.key, - "assistant_id": val.get("assistant_id"), - "user_id": val.get("user_id"), - "display_name": val.get("display_name"), - "status": val.get("status", "idle"), - "metadata": val.get("metadata", {}), - "created_at": str(val.get("created_at", "")), - "updated_at": str(val.get("updated_at", "")), - } diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/model.py b/backend/packages/harness/deerflow/persistence/thread_meta/model.py deleted file mode 100644 index fe15315e1..000000000 --- a/backend/packages/harness/deerflow/persistence/thread_meta/model.py +++ /dev/null @@ -1,23 +0,0 @@ -"""ORM model for thread metadata.""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import JSON, DateTime, String -from sqlalchemy.orm import Mapped, mapped_column - -from deerflow.persistence.base import Base - - -class ThreadMetaRow(Base): - __tablename__ = "threads_meta" - - thread_id: Mapped[str] = mapped_column(String(64), primary_key=True) - assistant_id: Mapped[str | None] = mapped_column(String(128), index=True) - user_id: Mapped[str | None] = mapped_column(String(64), index=True) - display_name: Mapped[str | None] = mapped_column(String(256)) - status: Mapped[str] = mapped_column(String(20), default="idle") - metadata_json: Mapped[dict] = mapped_column(JSON, default=dict) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC)) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC)) diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py deleted file mode 100644 index 688fbb247..000000000 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ /dev/null @@ -1,217 +0,0 @@ -"""SQLAlchemy-backed thread metadata repository.""" - -from __future__ import annotations - -from datetime import UTC, datetime -from typing import Any - -from sqlalchemy import select, update -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -from deerflow.persistence.thread_meta.base import ThreadMetaStore -from deerflow.persistence.thread_meta.model import ThreadMetaRow -from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id - - -class ThreadMetaRepository(ThreadMetaStore): - def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: - self._sf = session_factory - - @staticmethod - def _row_to_dict(row: ThreadMetaRow) -> dict[str, Any]: - d = row.to_dict() - d["metadata"] = d.pop("metadata_json", {}) - for key in ("created_at", "updated_at"): - val = d.get(key) - if isinstance(val, datetime): - d[key] = val.isoformat() - return d - - async def create( - self, - thread_id: str, - *, - assistant_id: str | None = None, - user_id: str | None | _AutoSentinel = AUTO, - display_name: str | None = None, - metadata: dict | None = None, - ) -> dict: - # Auto-resolve user_id from contextvar when AUTO; explicit None - # creates an orphan row (used by migration scripts). 
- resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create") - now = datetime.now(UTC) - row = ThreadMetaRow( - thread_id=thread_id, - assistant_id=assistant_id, - user_id=resolved_user_id, - display_name=display_name, - metadata_json=metadata or {}, - created_at=now, - updated_at=now, - ) - async with self._sf() as session: - session.add(row) - await session.commit() - await session.refresh(row) - return self._row_to_dict(row) - - async def get( - self, - thread_id: str, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> dict | None: - resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get") - async with self._sf() as session: - row = await session.get(ThreadMetaRow, thread_id) - if row is None: - return None - # Enforce owner filter unless explicitly bypassed (user_id=None). - if resolved_user_id is not None and row.user_id != resolved_user_id: - return None - return self._row_to_dict(row) - - async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: - """Check if ``user_id`` has access to ``thread_id``. - - Two modes — one row, two distinct semantics depending on what - the caller is about to do: - - - ``require_existing=False`` (default, permissive): - Returns True for: row missing (untracked legacy thread), - ``row.user_id`` is None (shared / pre-auth data), - or ``row.user_id == user_id``. Use for **read-style** - decorators where treating an untracked thread as accessible - preserves backward-compat. - - - ``require_existing=True`` (strict): - Returns True **only** when the row exists AND - (``row.user_id == user_id`` OR ``row.user_id is None``). - Use for **destructive / mutating** decorators (DELETE, PATCH, - state-update) so a thread that has *already been deleted* - cannot be re-targeted by any caller — closing the - delete-idempotence cross-user gap where the row vanishing - made every other user appear to "own" it. - """ - async with self._sf() as session: - row = await session.get(ThreadMetaRow, thread_id) - if row is None: - return not require_existing - if row.user_id is None: - return True - return row.user_id == user_id - - async def search( - self, - *, - metadata: dict | None = None, - status: str | None = None, - limit: int = 100, - offset: int = 0, - user_id: str | None | _AutoSentinel = AUTO, - ) -> list[dict]: - """Search threads with optional metadata and status filters. - - Owner filter is enforced by default: caller must be in a user - context. Pass ``user_id=None`` to bypass (migration/CLI). - """ - resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search") - stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) - if resolved_user_id is not None: - stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id) - if status: - stmt = stmt.where(ThreadMetaRow.status == status) - - if metadata: - # When metadata filter is active, fetch a larger window and filter - # in Python. TODO(Phase 2): use JSON DB operators (Postgres @>, - # SQLite json_extract) for server-side filtering. 
- stmt = stmt.limit(limit * 5 + offset) - async with self._sf() as session: - result = await session.execute(stmt) - rows = [self._row_to_dict(r) for r in result.scalars()] - rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())] - return rows[offset : offset + limit] - else: - stmt = stmt.limit(limit).offset(offset) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool: - """Return True if the row exists and is owned (or filter bypassed).""" - if resolved_user_id is None: - return True # explicit bypass - row = await session.get(ThreadMetaRow, thread_id) - return row is not None and row.user_id == resolved_user_id - - async def update_display_name( - self, - thread_id: str, - display_name: str, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> None: - """Update the display_name (title) for a thread.""" - resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name") - async with self._sf() as session: - if not await self._check_ownership(session, thread_id, resolved_user_id): - return - await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC))) - await session.commit() - - async def update_status( - self, - thread_id: str, - status: str, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> None: - resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status") - async with self._sf() as session: - if not await self._check_ownership(session, thread_id, resolved_user_id): - return - await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC))) - await session.commit() - - async def update_metadata( - self, - thread_id: str, - metadata: dict, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> None: - """Merge ``metadata`` into ``metadata_json``. - - Read-modify-write inside a single session/transaction so concurrent - callers see consistent state. No-op if the row does not exist or - the user_id check fails. 
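A short usage sketch of the merge semantics described above; repository construction is assumed to happen elsewhere, and the call relies on a user context already being set (user_id defaults to AUTO)::

    from deerflow.persistence.thread_meta.sql import ThreadMetaRepository


    async def tag_thread(repo: ThreadMetaRepository, thread_id: str) -> None:
        await repo.update_metadata(thread_id, {"stage": "draft", "source": "api"})
        # The second call merges into the stored dict: "stage" is overwritten,
        # "source" survives, and "reviewed" is added.
        await repo.update_metadata(thread_id, {"stage": "final", "reviewed": True})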
- """ - resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata") - async with self._sf() as session: - row = await session.get(ThreadMetaRow, thread_id) - if row is None: - return - if resolved_user_id is not None and row.user_id != resolved_user_id: - return - merged = dict(row.metadata_json or {}) - merged.update(metadata) - row.metadata_json = merged - row.updated_at = datetime.now(UTC) - await session.commit() - - async def delete( - self, - thread_id: str, - *, - user_id: str | None | _AutoSentinel = AUTO, - ) -> None: - resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete") - async with self._sf() as session: - row = await session.get(ThreadMetaRow, thread_id) - if row is None: - return - if resolved_user_id is not None and row.user_id != resolved_user_id: - return - await session.delete(row) - await session.commit() diff --git a/backend/packages/harness/deerflow/persistence/user/__init__.py b/backend/packages/harness/deerflow/persistence/user/__init__.py deleted file mode 100644 index a60eeef2c..000000000 --- a/backend/packages/harness/deerflow/persistence/user/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""User storage subpackage. - -Holds the ORM model for the ``users`` table. The concrete repository -implementation (``SQLiteUserRepository``) lives in the app layer -(``app.gateway.auth.repositories.sqlite``) because it converts -between the ORM row and the auth module's pydantic ``User`` class. -This keeps the harness package free of any dependency on app code. -""" - -from deerflow.persistence.user.model import UserRow - -__all__ = ["UserRow"] diff --git a/backend/packages/harness/deerflow/persistence/user/model.py b/backend/packages/harness/deerflow/persistence/user/model.py deleted file mode 100644 index 130d4bfcb..000000000 --- a/backend/packages/harness/deerflow/persistence/user/model.py +++ /dev/null @@ -1,59 +0,0 @@ -"""ORM model for the users table. - -Lives in the harness persistence package so it is picked up by -``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``, -``run_events``, and ``feedback``. Using the shared engine means: - -- One SQLite/Postgres database, one connection pool -- One schema initialisation codepath -- Consistent async sessions across auth and persistence reads -""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from sqlalchemy import Boolean, DateTime, Index, String, text -from sqlalchemy.orm import Mapped, mapped_column - -from deerflow.persistence.base import Base - - -class UserRow(Base): - __tablename__ = "users" - - # UUIDs are stored as 36-char strings for cross-backend portability. - id: Mapped[str] = mapped_column(String(36), primary_key=True) - - email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True) - password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True) - - # "admin" | "user" — kept as plain string to avoid ALTER TABLE pain - # when new roles are introduced. - system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user") - - created_at: Mapped[datetime] = mapped_column( - DateTime(timezone=True), - nullable=False, - default=lambda: datetime.now(UTC), - ) - - # OAuth linkage (optional). A partial unique index enforces one - # account per (provider, oauth_id) pair, leaving NULL/NULL rows - # unconstrained so plain password accounts can coexist. 
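The partial unique index described above is declared below with ``sqlite_where``, which only takes effect on SQLite; a hedged sketch of a cross-backend variant that would also pass ``postgresql_where`` so Postgres builds a partial rather than full unique index (this variant is not part of the original model)::

    from sqlalchemy import Index, text

    _OAUTH_PRESENT = text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL")

    # Would replace the entry in UserRow.__table_args__.
    idx_users_oauth_identity = Index(
        "idx_users_oauth_identity",
        "oauth_provider",
        "oauth_id",
        unique=True,
        sqlite_where=_OAUTH_PRESENT,
        postgresql_where=_OAUTH_PRESENT,
    )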
- oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True) - oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True) - - # Auth lifecycle flags - needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - token_version: Mapped[int] = mapped_column(nullable=False, default=0) - - __table_args__ = ( - Index( - "idx_users_oauth_identity", - "oauth_provider", - "oauth_id", - unique=True, - sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"), - ), - ) diff --git a/backend/packages/harness/deerflow/runtime/checkpointer/__init__.py b/backend/packages/harness/deerflow/runtime/checkpointer/__init__.py deleted file mode 100644 index 7bb0019a2..000000000 --- a/backend/packages/harness/deerflow/runtime/checkpointer/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .async_provider import make_checkpointer -from .provider import checkpointer_context, get_checkpointer, reset_checkpointer - -__all__ = [ - "get_checkpointer", - "reset_checkpointer", - "checkpointer_context", - "make_checkpointer", -] diff --git a/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py b/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py deleted file mode 100644 index 21c747b45..000000000 --- a/backend/packages/harness/deerflow/runtime/checkpointer/async_provider.py +++ /dev/null @@ -1,159 +0,0 @@ -"""Async checkpointer factory. - -Provides an **async context manager** for long-running async servers that need -proper resource cleanup. - -Supported backends: memory, sqlite, postgres. - -Usage (e.g. FastAPI lifespan):: - - from deerflow.runtime.checkpointer.async_provider import make_checkpointer - - async with make_checkpointer() as checkpointer: - app.state.checkpointer = checkpointer # InMemorySaver if not configured - -For sync usage see :mod:`deerflow.runtime.checkpointer.provider`. 
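A slightly fuller sketch of the lifespan wiring shown in the usage above, entering both the checkpointer and the store for the application's lifetime; the FastAPI app and the ``app.state`` attribute names follow the docstring's example and are illustrative::

    import contextlib

    from fastapi import FastAPI

    from deerflow.runtime.checkpointer import make_checkpointer
    from deerflow.runtime.store import make_store


    @contextlib.asynccontextmanager
    async def lifespan(app: FastAPI):
        async with contextlib.AsyncExitStack() as stack:
            # Both context managers stay open until the server shuts down.
            app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
            app.state.store = await stack.enter_async_context(make_store())
            yield


    app = FastAPI(lifespan=lifespan)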
-""" - -from __future__ import annotations - -import asyncio -import contextlib -import logging -from collections.abc import AsyncIterator - -from langgraph.types import Checkpointer - -from deerflow.config.app_config import get_app_config -from deerflow.runtime.checkpointer.provider import ( - POSTGRES_CONN_REQUIRED, - POSTGRES_INSTALL, - SQLITE_INSTALL, -) -from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Async factory -# --------------------------------------------------------------------------- - - -@contextlib.asynccontextmanager -async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]: - """Async context manager that constructs and tears down a checkpointer.""" - if config.type == "memory": - from langgraph.checkpoint.memory import InMemorySaver - - yield InMemorySaver() - return - - if config.type == "sqlite": - try: - from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver - except ImportError as exc: - raise ImportError(SQLITE_INSTALL) from exc - - conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") - await asyncio.to_thread(ensure_sqlite_parent_dir, conn_str) - async with AsyncSqliteSaver.from_conn_string(conn_str) as saver: - await saver.setup() - yield saver - return - - if config.type == "postgres": - try: - from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver - except ImportError as exc: - raise ImportError(POSTGRES_INSTALL) from exc - - if not config.connection_string: - raise ValueError(POSTGRES_CONN_REQUIRED) - - async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver: - await saver.setup() - yield saver - return - - raise ValueError(f"Unknown checkpointer type: {config.type!r}") - - -# --------------------------------------------------------------------------- -# Public async context manager -# --------------------------------------------------------------------------- - - -@contextlib.asynccontextmanager -async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpointer]: - """Async context manager that constructs a checkpointer from unified DatabaseConfig.""" - if db_config.backend == "memory": - from langgraph.checkpoint.memory import InMemorySaver - - yield InMemorySaver() - return - - if db_config.backend == "sqlite": - try: - from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver - except ImportError as exc: - raise ImportError(SQLITE_INSTALL) from exc - - conn_str = db_config.checkpointer_sqlite_path - ensure_sqlite_parent_dir(conn_str) - async with AsyncSqliteSaver.from_conn_string(conn_str) as saver: - await saver.setup() - yield saver - return - - if db_config.backend == "postgres": - try: - from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver - except ImportError as exc: - raise ImportError(POSTGRES_INSTALL) from exc - - if not db_config.postgres_url: - raise ValueError("database.postgres_url is required for the postgres backend") - - async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver: - await saver.setup() - yield saver - return - - raise ValueError(f"Unknown database backend: {db_config.backend!r}") - - -@contextlib.asynccontextmanager -async def make_checkpointer() -> AsyncIterator[Checkpointer]: - """Async context manager that yields a checkpointer for the caller's lifetime. 
- Resources are opened on enter and closed on exit -- no global state:: - - async with make_checkpointer() as checkpointer: - app.state.checkpointer = checkpointer - - Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. - - Priority: - 1. Legacy ``checkpointer:`` config section (backward compatible) - 2. Unified ``database:`` config section - 3. Default InMemorySaver - """ - - config = get_app_config() - - # Legacy: standalone checkpointer config takes precedence - if config.checkpointer is not None: - async with _async_checkpointer(config.checkpointer) as saver: - yield saver - return - - # Unified database config - db_config = getattr(config, "database", None) - if db_config is not None and db_config.backend != "memory": - async with _async_checkpointer_from_database(db_config) as saver: - yield saver - return - - # Default: in-memory - from langgraph.checkpoint.memory import InMemorySaver - - yield InMemorySaver() diff --git a/backend/packages/harness/deerflow/runtime/checkpointer/provider.py b/backend/packages/harness/deerflow/runtime/checkpointer/provider.py deleted file mode 100644 index 59f8b1ab2..000000000 --- a/backend/packages/harness/deerflow/runtime/checkpointer/provider.py +++ /dev/null @@ -1,191 +0,0 @@ -"""Sync checkpointer factory. - -Provides a **sync singleton** and a **sync context manager** for LangGraph -graph compilation and CLI tools. - -Supported backends: memory, sqlite, postgres. - -Usage:: - - from deerflow.runtime.checkpointer.provider import get_checkpointer, checkpointer_context - - # Singleton — reused across calls, closed on process exit - cp = get_checkpointer() - - # One-shot — fresh connection, closed on block exit - with checkpointer_context() as cp: - graph.invoke(input, config={"configurable": {"thread_id": "1"}}) -""" - -from __future__ import annotations - -import contextlib -import logging -from collections.abc import Iterator - -from langgraph.types import Checkpointer - -from deerflow.config.app_config import get_app_config -from deerflow.config.checkpointer_config import CheckpointerConfig -from deerflow.runtime.store._sqlite_utils import resolve_sqlite_conn_str - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Error message constants — imported by aio.provider too -# --------------------------------------------------------------------------- - -SQLITE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite checkpointer. Install it with: uv add langgraph-checkpoint-sqlite" -POSTGRES_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL checkpointer. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool" -POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend" - -# --------------------------------------------------------------------------- -# Sync factory -# --------------------------------------------------------------------------- - - -@contextlib.contextmanager -def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]: - """Context manager that creates and tears down a sync checkpointer. - - Returns a configured ``Checkpointer`` instance. Resource cleanup for any - underlying connections or pools is handled by higher-level helpers in - this module (such as the singleton factory or context manager); this - function does not return a separate cleanup callback. 
- """ - if config.type == "memory": - from langgraph.checkpoint.memory import InMemorySaver - - logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") - yield InMemorySaver() - return - - if config.type == "sqlite": - try: - from langgraph.checkpoint.sqlite import SqliteSaver - except ImportError as exc: - raise ImportError(SQLITE_INSTALL) from exc - - conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") - with SqliteSaver.from_conn_string(conn_str) as saver: - saver.setup() - logger.info("Checkpointer: using SqliteSaver (%s)", conn_str) - yield saver - return - - if config.type == "postgres": - try: - from langgraph.checkpoint.postgres import PostgresSaver - except ImportError as exc: - raise ImportError(POSTGRES_INSTALL) from exc - - if not config.connection_string: - raise ValueError(POSTGRES_CONN_REQUIRED) - - with PostgresSaver.from_conn_string(config.connection_string) as saver: - saver.setup() - logger.info("Checkpointer: using PostgresSaver") - yield saver - return - - raise ValueError(f"Unknown checkpointer type: {config.type!r}") - - -# --------------------------------------------------------------------------- -# Sync singleton -# --------------------------------------------------------------------------- - -_checkpointer: Checkpointer | None = None -_checkpointer_ctx = None # open context manager keeping the connection alive - - -def get_checkpointer() -> Checkpointer: - """Return the global sync checkpointer singleton, creating it on first call. - - Returns an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. - - Raises: - ImportError: If the required package for the configured backend is not installed. - ValueError: If ``connection_string`` is missing for a backend that requires it. - """ - global _checkpointer, _checkpointer_ctx - - if _checkpointer is not None: - return _checkpointer - - # Ensure app config is loaded before checking checkpointer config - # This prevents returning InMemorySaver when config.yaml actually has a checkpointer section - # but hasn't been loaded yet - from deerflow.config.app_config import _app_config - from deerflow.config.checkpointer_config import get_checkpointer_config - - config = get_checkpointer_config() - - if config is None and _app_config is None: - # Only load app config lazily when neither the app config nor an explicit - # checkpointer config has been initialized yet. This keeps tests that - # intentionally set the global checkpointer config isolated from any - # ambient config.yaml on disk. - try: - get_app_config() - except FileNotFoundError: - # In test environments without config.yaml, this is expected. - pass - config = get_checkpointer_config() - if config is None: - from langgraph.checkpoint.memory import InMemorySaver - - logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") - _checkpointer = InMemorySaver() - return _checkpointer - - _checkpointer_ctx = _sync_checkpointer_cm(config) - _checkpointer = _checkpointer_ctx.__enter__() - - return _checkpointer - - -def reset_checkpointer() -> None: - """Reset the sync singleton, forcing recreation on the next call. - - Closes any open backend connections and clears the cached instance. - Useful in tests or after a configuration change. 
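One way to use this in a test suite is an autouse fixture so each test starts from a clean singleton; a hedged sketch, with the fixture name being illustrative::

    import pytest

    from deerflow.runtime.checkpointer import reset_checkpointer


    @pytest.fixture(autouse=True)
    def _fresh_checkpointer():
        yield
        # Close any backend connection opened during the test and drop the cache.
        reset_checkpointer()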
- """ - global _checkpointer, _checkpointer_ctx - if _checkpointer_ctx is not None: - try: - _checkpointer_ctx.__exit__(None, None, None) - except Exception: - logger.warning("Error during checkpointer cleanup", exc_info=True) - _checkpointer_ctx = None - _checkpointer = None - - -# --------------------------------------------------------------------------- -# Sync context manager -# --------------------------------------------------------------------------- - - -@contextlib.contextmanager -def checkpointer_context() -> Iterator[Checkpointer]: - """Sync context manager that yields a checkpointer and cleans up on exit. - - Unlike :func:`get_checkpointer`, this does **not** cache the instance — - each ``with`` block creates and destroys its own connection. Use it in - CLI scripts or tests where you want deterministic cleanup:: - - with checkpointer_context() as cp: - graph.invoke(input, config={"configurable": {"thread_id": "1"}}) - - Yields an ``InMemorySaver`` when no checkpointer is configured in *config.yaml*. - """ - - config = get_app_config() - if config.checkpointer is None: - from langgraph.checkpoint.memory import InMemorySaver - - yield InMemorySaver() - return - - with _sync_checkpointer_cm(config.checkpointer) as saver: - yield saver diff --git a/backend/packages/harness/deerflow/runtime/events/__init__.py b/backend/packages/harness/deerflow/runtime/events/__init__.py deleted file mode 100644 index 0da8fabe5..000000000 --- a/backend/packages/harness/deerflow/runtime/events/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from deerflow.runtime.events.store.base import RunEventStore -from deerflow.runtime.events.store.memory import MemoryRunEventStore - -__all__ = ["MemoryRunEventStore", "RunEventStore"] diff --git a/backend/packages/harness/deerflow/runtime/events/store/__init__.py b/backend/packages/harness/deerflow/runtime/events/store/__init__.py deleted file mode 100644 index 55f0dd33f..000000000 --- a/backend/packages/harness/deerflow/runtime/events/store/__init__.py +++ /dev/null @@ -1,26 +0,0 @@ -from deerflow.runtime.events.store.base import RunEventStore -from deerflow.runtime.events.store.memory import MemoryRunEventStore - - -def make_run_event_store(config=None) -> RunEventStore: - """Create a RunEventStore based on run_events.backend configuration.""" - if config is None or config.backend == "memory": - return MemoryRunEventStore() - if config.backend == "db": - from deerflow.persistence.engine import get_session_factory - - sf = get_session_factory() - if sf is None: - # database.backend=memory but run_events.backend=db -> fallback - return MemoryRunEventStore() - from deerflow.runtime.events.store.db import DbRunEventStore - - return DbRunEventStore(sf, max_trace_content=config.max_trace_content) - if config.backend == "jsonl": - from deerflow.runtime.events.store.jsonl import JsonlRunEventStore - - return JsonlRunEventStore() - raise ValueError(f"Unknown run_events backend: {config.backend!r}") - - -__all__ = ["MemoryRunEventStore", "RunEventStore", "make_run_event_store"] diff --git a/backend/packages/harness/deerflow/runtime/events/store/base.py b/backend/packages/harness/deerflow/runtime/events/store/base.py deleted file mode 100644 index df5136ba5..000000000 --- a/backend/packages/harness/deerflow/runtime/events/store/base.py +++ /dev/null @@ -1,109 +0,0 @@ -"""Abstract interface for run event storage. - -RunEventStore is the unified storage interface for run event streams. 
-Messages (frontend display) and execution traces (debugging/audit) go -through the same interface, distinguished by the ``category`` field. - -Implementations: -- MemoryRunEventStore: in-memory dict (development, tests) -- Future: DB-backed store (SQLAlchemy ORM), JSONL file store -""" - -from __future__ import annotations - -import abc - - -class RunEventStore(abc.ABC): - """Run event stream storage interface. - - All implementations must guarantee: - 1. put() events are retrievable in subsequent queries - 2. seq is strictly increasing within the same thread - 3. list_messages() only returns category="message" events - 4. list_events() returns all events for the specified run - 5. Returned dicts match the RunEvent field structure - """ - - @abc.abstractmethod - async def put( - self, - *, - thread_id: str, - run_id: str, - event_type: str, - category: str, - content: str | dict = "", - metadata: dict | None = None, - created_at: str | None = None, - ) -> dict: - """Write an event, auto-assign seq, return the complete record.""" - - @abc.abstractmethod - async def put_batch(self, events: list[dict]) -> list[dict]: - """Batch-write events. Used by RunJournal flush buffer. - - Each dict's keys match put()'s keyword arguments. - Returns complete records with seq assigned. - """ - - @abc.abstractmethod - async def list_messages( - self, - thread_id: str, - *, - limit: int = 50, - before_seq: int | None = None, - after_seq: int | None = None, - ) -> list[dict]: - """Return displayable messages (category=message) for a thread, ordered by seq ascending. - - Supports bidirectional cursor pagination: - - before_seq: return the last ``limit`` records with seq < before_seq (ascending) - - after_seq: return the first ``limit`` records with seq > after_seq (ascending) - - neither: return the latest ``limit`` records (ascending) - """ - - @abc.abstractmethod - async def list_events( - self, - thread_id: str, - run_id: str, - *, - event_types: list[str] | None = None, - limit: int = 500, - ) -> list[dict]: - """Return the full event stream for a run, ordered by seq ascending. - - Optionally filter by event_types. - """ - - @abc.abstractmethod - async def list_messages_by_run( - self, - thread_id: str, - run_id: str, - *, - limit: int = 50, - before_seq: int | None = None, - after_seq: int | None = None, - ) -> list[dict]: - """Return displayable messages (category=message) for a specific run, ordered by seq ascending. - - Supports bidirectional cursor pagination: - - after_seq: return the first ``limit`` records with seq > after_seq (ascending) - - before_seq: return the last ``limit`` records with seq < before_seq (ascending) - - neither: return the latest ``limit`` records (ascending) - """ - - @abc.abstractmethod - async def count_messages(self, thread_id: str) -> int: - """Count displayable messages (category=message) in a thread.""" - - @abc.abstractmethod - async def delete_by_thread(self, thread_id: str) -> int: - """Delete all events for a thread. Return the number of deleted events.""" - - @abc.abstractmethod - async def delete_by_run(self, thread_id: str, run_id: str) -> int: - """Delete all events for a specific run. Return the number of deleted events.""" diff --git a/backend/packages/harness/deerflow/runtime/events/store/db.py b/backend/packages/harness/deerflow/runtime/events/store/db.py deleted file mode 100644 index e4a21d006..000000000 --- a/backend/packages/harness/deerflow/runtime/events/store/db.py +++ /dev/null @@ -1,286 +0,0 @@ -"""SQLAlchemy-backed RunEventStore implementation. 
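To make the pagination contract above concrete, a small sketch against the in-memory backend; the thread and run IDs are illustrative::

    import asyncio

    from deerflow.runtime.events.store.memory import MemoryRunEventStore


    async def demo() -> None:
        store = MemoryRunEventStore()
        for i in range(5):
            await store.put(
                thread_id="t1",
                run_id="r1",
                event_type="llm.ai.response",
                category="message",
                content=f"message {i}",
            )
        latest = await store.list_messages("t1", limit=2)                              # seq [4, 5]
        older = await store.list_messages("t1", before_seq=latest[0]["seq"], limit=2)  # seq [2, 3]
        newer = await store.list_messages("t1", after_seq=older[-1]["seq"], limit=2)   # seq [4, 5]
        print(latest, older, newer)


    asyncio.run(demo())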
- -Persists events to the ``run_events`` table. Trace content is truncated -at ``max_trace_content`` bytes to avoid bloating the database. -""" - -from __future__ import annotations - -import json -import logging -from datetime import UTC, datetime - -from sqlalchemy import delete, func, select -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker - -from deerflow.persistence.models.run_event import RunEventRow -from deerflow.runtime.events.store.base import RunEventStore -from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id - -logger = logging.getLogger(__name__) - - -class DbRunEventStore(RunEventStore): - def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240): - self._sf = session_factory - self._max_trace_content = max_trace_content - - @staticmethod - def _row_to_dict(row: RunEventRow) -> dict: - d = row.to_dict() - d["metadata"] = d.pop("event_metadata", {}) - val = d.get("created_at") - if isinstance(val, datetime): - d["created_at"] = val.isoformat() - d.pop("id", None) - # Restore dict content that was JSON-serialized on write - raw = d.get("content", "") - if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"): - try: - d["content"] = json.loads(raw) - except (json.JSONDecodeError, ValueError): - # Content looked like JSON (content_is_dict flag) but failed to parse; - # keep the raw string as-is. - logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq")) - return d - - def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]: - if category == "trace": - text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content - encoded = text.encode("utf-8") - if len(encoded) > self._max_trace_content: - # Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore") - content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore") - metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)} - return content, metadata or {} - - @staticmethod - def _user_id_from_context() -> str | None: - """Soft read of user_id from contextvar for write paths. - - Returns ``None`` (no filter / no stamp) if contextvar is unset, - which is the expected case for background worker writes. HTTP - request writes will have the contextvar set by auth middleware - and get their user_id stamped automatically. - - Coerces ``user.id`` to ``str`` at the boundary: ``User.id`` is - typed as ``UUID`` by the auth layer, but ``run_events.user_id`` - is ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object - to a VARCHAR column ("type 'UUID' is not supported") — the - INSERT would silently roll back and the worker would hang. - """ - user = get_current_user() - return str(user.id) if user is not None else None - - async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401 - """Write a single event — low-frequency path only. - - This opens a dedicated transaction with a FOR UPDATE lock to - assign a monotonic *seq*. For high-throughput writes use - :meth:`put_batch`, which acquires the lock once for the whole - batch. Currently the only caller is ``worker.run_agent`` for - the initial ``human_message`` event (once per run). 
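The byte-level truncation used by ``_truncate_trace`` relies on ``errors="ignore"`` dropping any multi-byte character that the cut splits in half; a minimal stand-alone illustration::

    MAX_BYTES = 9

    text = "héllo wörld"                     # contains multi-byte UTF-8 characters
    encoded = text.encode("utf-8")           # 13 bytes
    truncated = encoded[:MAX_BYTES].decode("utf-8", errors="ignore")

    print(truncated)                         # "héllo w": the half-cut "ö" is dropped silently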
- """ - content, metadata = self._truncate_trace(category, content, metadata) - if isinstance(content, dict): - db_content = json.dumps(content, default=str, ensure_ascii=False) - metadata = {**(metadata or {}), "content_is_dict": True} - else: - db_content = content - user_id = self._user_id_from_context() - async with self._sf() as session: - async with session.begin(): - # Use FOR UPDATE to serialize seq assignment within a thread. - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) - seq = (max_seq or 0) + 1 - row = RunEventRow( - thread_id=thread_id, - run_id=run_id, - user_id=user_id, - event_type=event_type, - category=category, - content=db_content, - event_metadata=metadata, - seq=seq, - created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC), - ) - session.add(row) - return self._row_to_dict(row) - - async def put_batch(self, events): - if not events: - return [] - user_id = self._user_id_from_context() - async with self._sf() as session: - async with session.begin(): - # Get max seq for the thread (assume all events in batch belong to same thread). - # NOTE: with_for_update() on aggregates is a no-op on SQLite; - # the UNIQUE(thread_id, seq) constraint catches races there. - thread_id = events[0]["thread_id"] - max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update()) - seq = max_seq or 0 - rows = [] - for e in events: - seq += 1 - content = e.get("content", "") - category = e.get("category", "trace") - metadata = e.get("metadata") - content, metadata = self._truncate_trace(category, content, metadata) - if isinstance(content, dict): - db_content = json.dumps(content, default=str, ensure_ascii=False) - metadata = {**(metadata or {}), "content_is_dict": True} - else: - db_content = content - row = RunEventRow( - thread_id=e["thread_id"], - run_id=e["run_id"], - user_id=e.get("user_id", user_id), - event_type=e["event_type"], - category=category, - content=db_content, - event_metadata=metadata, - seq=seq, - created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC), - ) - session.add(row) - rows.append(row) - return [self._row_to_dict(r) for r in rows] - - async def list_messages( - self, - thread_id, - *, - limit=50, - before_seq=None, - after_seq=None, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages") - stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") - if resolved_user_id is not None: - stmt = stmt.where(RunEventRow.user_id == resolved_user_id) - if before_seq is not None: - stmt = stmt.where(RunEventRow.seq < before_seq) - if after_seq is not None: - stmt = stmt.where(RunEventRow.seq > after_seq) - - if after_seq is not None: - # Forward pagination: first `limit` records after cursor - stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - else: - # before_seq or default (latest): take last `limit` records, return ascending - stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - rows = 
list(result.scalars()) - return [self._row_to_dict(r) for r in reversed(rows)] - - async def list_events( - self, - thread_id, - run_id, - *, - event_types=None, - limit=500, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events") - stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id) - if resolved_user_id is not None: - stmt = stmt.where(RunEventRow.user_id == resolved_user_id) - if event_types: - stmt = stmt.where(RunEventRow.event_type.in_(event_types)) - stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - - async def list_messages_by_run( - self, - thread_id, - run_id, - *, - limit=50, - before_seq=None, - after_seq=None, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run") - stmt = select(RunEventRow).where( - RunEventRow.thread_id == thread_id, - RunEventRow.run_id == run_id, - RunEventRow.category == "message", - ) - if resolved_user_id is not None: - stmt = stmt.where(RunEventRow.user_id == resolved_user_id) - if before_seq is not None: - stmt = stmt.where(RunEventRow.seq < before_seq) - if after_seq is not None: - stmt = stmt.where(RunEventRow.seq > after_seq) - - if after_seq is not None: - stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - return [self._row_to_dict(r) for r in result.scalars()] - else: - stmt = stmt.order_by(RunEventRow.seq.desc()).limit(limit) - async with self._sf() as session: - result = await session.execute(stmt) - rows = list(result.scalars()) - return [self._row_to_dict(r) for r in reversed(rows)] - - async def count_messages( - self, - thread_id, - *, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages") - stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") - if resolved_user_id is not None: - stmt = stmt.where(RunEventRow.user_id == resolved_user_id) - async with self._sf() as session: - return await session.scalar(stmt) or 0 - - async def delete_by_thread( - self, - thread_id, - *, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread") - async with self._sf() as session: - count_conditions = [RunEventRow.thread_id == thread_id] - if resolved_user_id is not None: - count_conditions.append(RunEventRow.user_id == resolved_user_id) - count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions) - count = await session.scalar(count_stmt) or 0 - if count > 0: - await session.execute(delete(RunEventRow).where(*count_conditions)) - await session.commit() - return count - - async def delete_by_run( - self, - thread_id, - run_id, - *, - user_id: str | None | _AutoSentinel = AUTO, - ): - resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run") - async with self._sf() as session: - count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id] - if resolved_user_id is not None: - count_conditions.append(RunEventRow.user_id == resolved_user_id) - count_stmt = 
select(func.count()).select_from(RunEventRow).where(*count_conditions) - count = await session.scalar(count_stmt) or 0 - if count > 0: - await session.execute(delete(RunEventRow).where(*count_conditions)) - await session.commit() - return count diff --git a/backend/packages/harness/deerflow/runtime/events/store/jsonl.py b/backend/packages/harness/deerflow/runtime/events/store/jsonl.py deleted file mode 100644 index 378713afc..000000000 --- a/backend/packages/harness/deerflow/runtime/events/store/jsonl.py +++ /dev/null @@ -1,187 +0,0 @@ -"""JSONL file-backed RunEventStore implementation. - -Each run's events are stored in a single file: -``.deer-flow/threads/{thread_id}/runs/{run_id}.jsonl`` - -All categories (message, trace, lifecycle) are in the same file. -This backend is suitable for lightweight single-node deployments. - -Known trade-off: ``list_messages()`` must scan all run files for a -thread since messages from multiple runs need unified seq ordering. -``list_events()`` reads only one file -- the fast path. -""" - -from __future__ import annotations - -import json -import logging -import re -from datetime import UTC, datetime -from pathlib import Path - -from deerflow.runtime.events.store.base import RunEventStore - -logger = logging.getLogger(__name__) - -_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$") - - -class JsonlRunEventStore(RunEventStore): - def __init__(self, base_dir: str | Path | None = None): - self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow") - self._seq_counters: dict[str, int] = {} # thread_id -> current max seq - - @staticmethod - def _validate_id(value: str, label: str) -> str: - """Validate that an ID is safe for use in filesystem paths.""" - if not value or not _SAFE_ID_PATTERN.match(value): - raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}") - return value - - def _thread_dir(self, thread_id: str) -> Path: - self._validate_id(thread_id, "thread_id") - return self._base_dir / "threads" / thread_id / "runs" - - def _run_file(self, thread_id: str, run_id: str) -> Path: - self._validate_id(run_id, "run_id") - return self._thread_dir(thread_id) / f"{run_id}.jsonl" - - def _next_seq(self, thread_id: str) -> int: - self._seq_counters[thread_id] = self._seq_counters.get(thread_id, 0) + 1 - return self._seq_counters[thread_id] - - def _ensure_seq_loaded(self, thread_id: str) -> None: - """Load max seq from existing files if not yet cached.""" - if thread_id in self._seq_counters: - return - max_seq = 0 - thread_dir = self._thread_dir(thread_id) - if thread_dir.exists(): - for f in thread_dir.glob("*.jsonl"): - for line in f.read_text(encoding="utf-8").strip().splitlines(): - try: - record = json.loads(line) - max_seq = max(max_seq, record.get("seq", 0)) - except json.JSONDecodeError: - logger.debug("Skipping malformed JSONL line in %s", f) - continue - self._seq_counters[thread_id] = max_seq - - def _write_record(self, record: dict) -> None: - path = self._run_file(record["thread_id"], record["run_id"]) - path.parent.mkdir(parents=True, exist_ok=True) - with open(path, "a", encoding="utf-8") as f: - f.write(json.dumps(record, default=str, ensure_ascii=False) + "\n") - - def _read_thread_events(self, thread_id: str) -> list[dict]: - """Read all events for a thread, sorted by seq.""" - events = [] - thread_dir = self._thread_dir(thread_id) - if not thread_dir.exists(): - return events - for f in sorted(thread_dir.glob("*.jsonl")): - for line in f.read_text(encoding="utf-8").strip().splitlines(): - if not line: 
- continue - try: - events.append(json.loads(line)) - except json.JSONDecodeError: - logger.debug("Skipping malformed JSONL line in %s", f) - continue - events.sort(key=lambda e: e.get("seq", 0)) - return events - - def _read_run_events(self, thread_id: str, run_id: str) -> list[dict]: - """Read events for a specific run file.""" - path = self._run_file(thread_id, run_id) - if not path.exists(): - return [] - events = [] - for line in path.read_text(encoding="utf-8").strip().splitlines(): - if not line: - continue - try: - events.append(json.loads(line)) - except json.JSONDecodeError: - logger.debug("Skipping malformed JSONL line in %s", path) - continue - events.sort(key=lambda e: e.get("seq", 0)) - return events - - async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): - self._ensure_seq_loaded(thread_id) - seq = self._next_seq(thread_id) - record = { - "thread_id": thread_id, - "run_id": run_id, - "event_type": event_type, - "category": category, - "content": content, - "metadata": metadata or {}, - "seq": seq, - "created_at": created_at or datetime.now(UTC).isoformat(), - } - self._write_record(record) - return record - - async def put_batch(self, events): - if not events: - return [] - results = [] - for ev in events: - record = await self.put(**ev) - results.append(record) - return results - - async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None): - all_events = self._read_thread_events(thread_id) - messages = [e for e in all_events if e.get("category") == "message"] - - if before_seq is not None: - messages = [e for e in messages if e["seq"] < before_seq] - return messages[-limit:] - elif after_seq is not None: - messages = [e for e in messages if e["seq"] > after_seq] - return messages[:limit] - else: - return messages[-limit:] - - async def list_events(self, thread_id, run_id, *, event_types=None, limit=500): - events = self._read_run_events(thread_id, run_id) - if event_types is not None: - events = [e for e in events if e.get("event_type") in event_types] - return events[:limit] - - async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None): - events = self._read_run_events(thread_id, run_id) - filtered = [e for e in events if e.get("category") == "message"] - if before_seq is not None: - filtered = [e for e in filtered if e.get("seq", 0) < before_seq] - if after_seq is not None: - filtered = [e for e in filtered if e.get("seq", 0) > after_seq] - if after_seq is not None: - return filtered[:limit] - else: - return filtered[-limit:] if len(filtered) > limit else filtered - - async def count_messages(self, thread_id): - all_events = self._read_thread_events(thread_id) - return sum(1 for e in all_events if e.get("category") == "message") - - async def delete_by_thread(self, thread_id): - all_events = self._read_thread_events(thread_id) - count = len(all_events) - thread_dir = self._thread_dir(thread_id) - if thread_dir.exists(): - for f in thread_dir.glob("*.jsonl"): - f.unlink() - self._seq_counters.pop(thread_id, None) - return count - - async def delete_by_run(self, thread_id, run_id): - events = self._read_run_events(thread_id, run_id) - count = len(events) - path = self._run_file(thread_id, run_id) - if path.exists(): - path.unlink() - return count diff --git a/backend/packages/harness/deerflow/runtime/events/store/memory.py b/backend/packages/harness/deerflow/runtime/events/store/memory.py deleted file mode 100644 index cf70e1cdf..000000000 --- 
a/backend/packages/harness/deerflow/runtime/events/store/memory.py +++ /dev/null @@ -1,128 +0,0 @@ -"""In-memory RunEventStore. Used when run_events.backend=memory (default) and in tests. - -Thread-safe for single-process async usage (no threading locks needed -since all mutations happen within the same event loop). -""" - -from __future__ import annotations - -from datetime import UTC, datetime - -from deerflow.runtime.events.store.base import RunEventStore - - -class MemoryRunEventStore(RunEventStore): - def __init__(self) -> None: - self._events: dict[str, list[dict]] = {} # thread_id -> sorted event list - self._seq_counters: dict[str, int] = {} # thread_id -> last assigned seq - - def _next_seq(self, thread_id: str) -> int: - current = self._seq_counters.get(thread_id, 0) - next_val = current + 1 - self._seq_counters[thread_id] = next_val - return next_val - - def _put_one( - self, - *, - thread_id: str, - run_id: str, - event_type: str, - category: str, - content: str | dict = "", - metadata: dict | None = None, - created_at: str | None = None, - ) -> dict: - seq = self._next_seq(thread_id) - record = { - "thread_id": thread_id, - "run_id": run_id, - "event_type": event_type, - "category": category, - "content": content, - "metadata": metadata or {}, - "seq": seq, - "created_at": created_at or datetime.now(UTC).isoformat(), - } - self._events.setdefault(thread_id, []).append(record) - return record - - async def put( - self, - *, - thread_id, - run_id, - event_type, - category, - content="", - metadata=None, - created_at=None, - ): - return self._put_one( - thread_id=thread_id, - run_id=run_id, - event_type=event_type, - category=category, - content=content, - metadata=metadata, - created_at=created_at, - ) - - async def put_batch(self, events): - results = [] - for ev in events: - record = self._put_one(**ev) - results.append(record) - return results - - async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None): - all_events = self._events.get(thread_id, []) - messages = [e for e in all_events if e["category"] == "message"] - - if before_seq is not None: - messages = [e for e in messages if e["seq"] < before_seq] - # Take the last `limit` records - return messages[-limit:] - elif after_seq is not None: - messages = [e for e in messages if e["seq"] > after_seq] - return messages[:limit] - else: - # Return the latest `limit` records, ascending - return messages[-limit:] - - async def list_events(self, thread_id, run_id, *, event_types=None, limit=500): - all_events = self._events.get(thread_id, []) - filtered = [e for e in all_events if e["run_id"] == run_id] - if event_types is not None: - filtered = [e for e in filtered if e["event_type"] in event_types] - return filtered[:limit] - - async def list_messages_by_run(self, thread_id, run_id, *, limit=50, before_seq=None, after_seq=None): - all_events = self._events.get(thread_id, []) - filtered = [e for e in all_events if e["run_id"] == run_id and e["category"] == "message"] - if before_seq is not None: - filtered = [e for e in filtered if e["seq"] < before_seq] - if after_seq is not None: - filtered = [e for e in filtered if e["seq"] > after_seq] - if after_seq is not None: - return filtered[:limit] - else: - return filtered[-limit:] if len(filtered) > limit else filtered - - async def count_messages(self, thread_id): - all_events = self._events.get(thread_id, []) - return sum(1 for e in all_events if e["category"] == "message") - - async def delete_by_thread(self, thread_id): - events = 
self._events.pop(thread_id, []) - self._seq_counters.pop(thread_id, None) - return len(events) - - async def delete_by_run(self, thread_id, run_id): - all_events = self._events.get(thread_id, []) - if not all_events: - return 0 - remaining = [e for e in all_events if e["run_id"] != run_id] - removed = len(all_events) - len(remaining) - self._events[thread_id] = remaining - return removed diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py deleted file mode 100644 index 5f1838888..000000000 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ /dev/null @@ -1,374 +0,0 @@ -"""Run event capture via LangChain callbacks. - -RunJournal sits between LangChain's callback mechanism and the pluggable -RunEventStore. It standardizes callback data into RunEvent records and -handles token usage accumulation. - -Key design decisions: -- on_llm_new_token is NOT implemented -- only complete messages via on_llm_end -- on_chat_model_start captures structured prompts as llm_request (OpenAI format) and - extracts the first human message for run.input, because it is more reliable than - on_chain_start (fires on every node) — messages here are fully structured. -- on_chain_start with parent_run_id=None emits a run.start trace marking root invocation. -- on_llm_end emits llm_response in OpenAI Chat Completions format -- Token usage accumulated in memory, written to RunRow on run completion -- Caller identification via tags injection (lead_agent / subagent:{name} / middleware:{name}) -""" - -from __future__ import annotations - -import asyncio -import logging -import time -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, cast -from uuid import UUID - -from langchain_core.callbacks import BaseCallbackHandler -from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage -from langgraph.types import Command - -if TYPE_CHECKING: - from deerflow.runtime.events.store.base import RunEventStore - -logger = logging.getLogger(__name__) - - -class RunJournal(BaseCallbackHandler): - """LangChain callback handler that captures events to RunEventStore.""" - - def __init__( - self, - run_id: str, - thread_id: str, - event_store: RunEventStore, - *, - track_token_usage: bool = True, - flush_threshold: int = 20, - ): - super().__init__() - self.run_id = run_id - self.thread_id = thread_id - self._store = event_store - self._track_tokens = track_token_usage - self._flush_threshold = flush_threshold - - # Write buffer - self._buffer: list[dict] = [] - self._pending_flush_tasks: set[asyncio.Task[None]] = set() - - # Token accumulators - self._total_input_tokens = 0 - self._total_output_tokens = 0 - self._total_tokens = 0 - self._llm_call_count = 0 - self._lead_agent_tokens = 0 - self._subagent_tokens = 0 - self._middleware_tokens = 0 - - # Convenience fields - self._last_ai_msg: str | None = None - self._first_human_msg: str | None = None - self._msg_count = 0 - - # Latency tracking - self._llm_start_times: dict[str, float] = {} # langchain run_id -> start time - - # LLM request/response tracking - self._llm_call_index = 0 - self._cached_prompts: dict[str, list[dict]] = {} # langchain run_id -> OpenAI messages - - # -- Lifecycle callbacks -- - - def on_chain_start( - self, - serialized: dict[str, Any], - inputs: dict[str, Any], - *, - run_id: UUID, - parent_run_id: UUID | None = None, - tags: list[str] | None = None, - metadata: dict[str, Any] | None = None, - **kwargs: Any, - ) -> None: - caller = 
self._identify_caller(tags)
-        if parent_run_id is None:
-            # Root graph invocation — emit a single trace event for the run start.
-            chain_name = (serialized or {}).get("name", "unknown")
-            self._put(
-                event_type="run.start",
-                category="trace",
-                content={"chain": chain_name},
-                metadata={"caller": caller, **(metadata or {})},
-            )
-
-    def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
-        self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
-        self._flush_sync()
-
-    def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
-        self._put(
-            event_type="run.error",
-            category="error",
-            content=str(error),
-            metadata={"error_type": type(error).__name__},
-        )
-        self._flush_sync()
-
-    # -- LLM callbacks --
-
-    def on_chat_model_start(
-        self,
-        serialized: dict,
-        messages: list[list[BaseMessage]],
-        *,
-        run_id: UUID,
-        tags: list[str] | None = None,
-        **kwargs: Any,
-    ) -> None:
-        """Capture structured prompt messages for llm_request event.
-
-        This is also the canonical place to extract the first human message:
-        messages are fully structured here, it fires only on real LLM calls,
-        and the content is never compressed by checkpoint trimming.
-        """
-        rid = str(run_id)
-        self._llm_start_times[rid] = time.monotonic()
-        self._llm_call_index += 1
-        # Mark this run_id as seen so on_llm_end knows not to increment again.
-        self._cached_prompts[rid] = []
-
-        logger.info(f"on_chat_model_start {run_id}: tags={tags} serialized={serialized} messages={messages}")
-
-        # Capture the first human message sent to any LLM in this run.
-        if not self._first_human_msg:
-            for batch in reversed(messages):
-                for m in reversed(batch):
-                    if isinstance(m, HumanMessage) and m.name != "summary":
-                        caller = self._identify_caller(tags)
-                        self.set_first_human_message(m.text)
-                        self._put(
-                            event_type="llm.human.input",
-                            category="message",
-                            content=m.model_dump(),
-                            metadata={"caller": caller},
-                        )
-                        break
-                if self._first_human_msg:
-                    break
-
-    def on_llm_start(self, serialized: dict, prompts: list[str], *, run_id: UUID, parent_run_id: UUID | None = None, tags: list[str] | None = None, metadata: dict[str, Any] | None = None, **kwargs: Any) -> None:
-        # Fallback: on_chat_model_start is preferred. This just tracks latency.
- self._llm_start_times[str(run_id)] = time.monotonic() - - def on_llm_end(self, response, *, run_id, parent_run_id, tags, **kwargs) -> None: - messages: list[AnyMessage] = [] - logger.info(f"on_llm_end {run_id}: response: {tags} {kwargs}") - for generation in response.generations: - for gen in generation: - if hasattr(gen, "message"): - messages.append(gen.message) - else: - logger.warning(f"on_llm_end {run_id}: generation has no message attribute: {gen}") - - for message in messages: - caller = self._identify_caller(tags) - - # Latency - rid = str(run_id) - start = self._llm_start_times.pop(rid, None) - latency_ms = int((time.monotonic() - start) * 1000) if start else None - - # Token usage from message - usage = getattr(message, "usage_metadata", None) - usage_dict = dict(usage) if usage else {} - - # Resolve call index - call_index = self._llm_call_index - if rid not in self._cached_prompts: - # Fallback: on_chat_model_start was not called - self._llm_call_index += 1 - call_index = self._llm_call_index - - # Trace event: llm_response (OpenAI completion format) - self._put( - event_type="llm.ai.response", - category="message", - content=message.model_dump(), - metadata={ - "caller": caller, - "usage": usage_dict, - "latency_ms": latency_ms, - "llm_call_index": call_index, - }, - ) - - # Token accumulation - if self._track_tokens: - input_tk = usage_dict.get("input_tokens", 0) or 0 - output_tk = usage_dict.get("output_tokens", 0) or 0 - total_tk = usage_dict.get("total_tokens", 0) or 0 - if total_tk == 0: - total_tk = input_tk + output_tk - if total_tk > 0: - self._total_input_tokens += input_tk - self._total_output_tokens += output_tk - self._total_tokens += total_tk - self._llm_call_count += 1 - - def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: - self._llm_start_times.pop(str(run_id), None) - self._put(event_type="llm.error", category="trace", content=str(error)) - - def on_tool_start(self, serialized, input_str, *, run_id, parent_run_id=None, tags=None, metadata=None, inputs=None, **kwargs): - """Handle tool start event, cache tool call ID for later correlation""" - tool_call_id = str(run_id) - logger.info(f"Tool start for node {run_id}, tool_call_id={tool_call_id}, tags={tags}, metadata={metadata}") - - def on_tool_end(self, output, *, run_id, parent_run_id=None, **kwargs): - """Handle tool end event, append message and clear node data""" - try: - if isinstance(output, ToolMessage): - msg = cast(ToolMessage, output) - self._put(event_type="llm.tool.result", category="message", content=msg.model_dump()) - elif isinstance(output, Command): - cmd = cast(Command, output) - messages = cmd.update.get("messages", []) - for message in messages: - if isinstance(message, BaseMessage): - self._put(event_type="llm.tool.result", category="message", content=message.model_dump()) - else: - logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}") - else: - logger.warning(f"on_tool_end {run_id}: output is not ToolMessage: {type(output)}") - finally: - logger.info(f"Tool end for node {run_id}") - - # -- Internal methods -- - - def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None: - self._buffer.append( - { - "thread_id": self.thread_id, - "run_id": self.run_id, - "event_type": event_type, - "category": category, - "content": content, - "metadata": metadata or {}, - "created_at": datetime.now(UTC).isoformat(), - } - ) - if len(self._buffer) >= 
self._flush_threshold: - self._flush_sync() - - def _flush_sync(self) -> None: - """Best-effort flush of buffer to RunEventStore. - - BaseCallbackHandler methods are synchronous. If an event loop is - running we schedule an async ``put_batch``; otherwise the events - stay in the buffer and are flushed later by the async ``flush()`` - call in the worker's ``finally`` block. - """ - if not self._buffer: - return - # Skip if a flush is already in flight — avoids concurrent writes - # to the same SQLite file from multiple fire-and-forget tasks. - if self._pending_flush_tasks: - return - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # No event loop — keep events in buffer for later async flush. - return - batch = self._buffer.copy() - self._buffer.clear() - task = loop.create_task(self._flush_async(batch)) - self._pending_flush_tasks.add(task) - task.add_done_callback(self._on_flush_done) - - async def _flush_async(self, batch: list[dict]) -> None: - try: - await self._store.put_batch(batch) - except Exception: - logger.warning( - "Failed to flush %d events for run %s — returning to buffer", - len(batch), - self.run_id, - exc_info=True, - ) - # Return failed events to buffer for retry on next flush - self._buffer = batch + self._buffer - - def _on_flush_done(self, task: asyncio.Task) -> None: - self._pending_flush_tasks.discard(task) - if task.cancelled(): - return - exc = task.exception() - if exc: - logger.warning("Journal flush task failed: %s", exc) - - def _identify_caller(self, tags: list[str] | None, **kwargs) -> str: - _tags = tags or kwargs.get("tags", []) - for tag in _tags: - if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"): - return tag - # Default to lead_agent: the main agent graph does not inject - # callback tags, while subagents and middleware explicitly tag - # themselves. - return "lead_agent" - - # -- Public methods (called by worker) -- - - def set_first_human_message(self, content: str) -> None: - """Record the first human message for convenience fields.""" - self._first_human_msg = content[:2000] if content else None - - def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None: - """Record a middleware state-change event. - - Called by middleware implementations when they perform a meaningful - state change (e.g., title generation, summarization, HITL approval). - Pure-observation middleware should not call this. - - Args: - tag: Short identifier for the middleware (e.g., "title", "summarize", - "guardrail"). Used to form event_type="middleware:{tag}". - name: Full middleware class name. - hook: Lifecycle hook that triggered the action (e.g., "after_model"). - action: Specific action performed (e.g., "generate_title"). - changes: Dict describing the state changes made. - """ - self._put( - event_type=f"middleware:{tag}", - category="middleware", - content={"name": name, "hook": hook, "action": action, "changes": changes}, - ) - - async def flush(self) -> None: - """Force flush remaining buffer. 
Called in worker's finally block.""" - if self._pending_flush_tasks: - await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) - - while self._buffer: - batch = self._buffer[: self._flush_threshold] - del self._buffer[: self._flush_threshold] - try: - await self._store.put_batch(batch) - except Exception: - self._buffer = batch + self._buffer - raise - - def get_completion_data(self) -> dict: - """Return accumulated token and message data for run completion.""" - return { - "total_input_tokens": self._total_input_tokens, - "total_output_tokens": self._total_output_tokens, - "total_tokens": self._total_tokens, - "llm_call_count": self._llm_call_count, - "lead_agent_tokens": self._lead_agent_tokens, - "subagent_tokens": self._subagent_tokens, - "middleware_tokens": self._middleware_tokens, - "message_count": self._msg_count, - "last_ai_message": self._last_ai_msg, - "first_human_message": self._first_human_msg, - } diff --git a/backend/packages/harness/deerflow/runtime/store/__init__.py b/backend/packages/harness/deerflow/runtime/store/__init__.py deleted file mode 100644 index 2f5e77aaa..000000000 --- a/backend/packages/harness/deerflow/runtime/store/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Store provider for the DeerFlow runtime. - -Re-exports the public API of both the async provider (for long-running -servers) and the sync provider (for CLI tools and the embedded client). - -Async usage (FastAPI lifespan):: - - from deerflow.runtime.store import make_store - - async with make_store() as store: - app.state.store = store - -Sync usage (CLI / DeerFlowClient):: - - from deerflow.runtime.store import get_store, store_context - - store = get_store() # singleton - with store_context() as store: ... # one-shot -""" - -from .async_provider import make_store -from .provider import get_store, reset_store, store_context - -__all__ = [ - # async - "make_store", - # sync - "get_store", - "reset_store", - "store_context", -] diff --git a/backend/packages/harness/deerflow/runtime/store/_sqlite_utils.py b/backend/packages/harness/deerflow/runtime/store/_sqlite_utils.py deleted file mode 100644 index bb970e572..000000000 --- a/backend/packages/harness/deerflow/runtime/store/_sqlite_utils.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Shared SQLite connection utilities for store and checkpointer providers.""" - -from __future__ import annotations - -import pathlib - -from deerflow.config.paths import resolve_path - - -def resolve_sqlite_conn_str(raw: str) -> str: - """Return a SQLite connection string ready for use with store/checkpointer backends. - - SQLite special strings (``":memory:"`` and ``file:`` URIs) are returned - unchanged. Plain filesystem paths — relative or absolute — are resolved - to an absolute string via :func:`resolve_path`. - """ - if raw == ":memory:" or raw.startswith("file:"): - return raw - return str(resolve_path(raw)) - - -def ensure_sqlite_parent_dir(conn_str: str) -> None: - """Create parent directory for a SQLite filesystem path. - - No-op for in-memory databases (``":memory:"``) and ``file:`` URIs. 
- """ - if conn_str != ":memory:" and not conn_str.startswith("file:"): - pathlib.Path(conn_str).parent.mkdir(parents=True, exist_ok=True) diff --git a/backend/packages/harness/deerflow/runtime/store/async_provider.py b/backend/packages/harness/deerflow/runtime/store/async_provider.py deleted file mode 100644 index 68cd107c8..000000000 --- a/backend/packages/harness/deerflow/runtime/store/async_provider.py +++ /dev/null @@ -1,113 +0,0 @@ -"""Async Store factory — backend mirrors the configured checkpointer. - -The store and checkpointer share the same ``checkpointer`` section in -*config.yaml* so they always use the same persistence backend: - -- ``type: memory`` → :class:`langgraph.store.memory.InMemoryStore` -- ``type: sqlite`` → :class:`langgraph.store.sqlite.aio.AsyncSqliteStore` -- ``type: postgres`` → :class:`langgraph.store.postgres.aio.AsyncPostgresStore` - -Usage (e.g. FastAPI lifespan):: - - from deerflow.runtime.store import make_store - - async with make_store() as store: - app.state.store = store -""" - -from __future__ import annotations - -import contextlib -import logging -from collections.abc import AsyncIterator - -from langgraph.store.base import BaseStore - -from deerflow.config.app_config import get_app_config -from deerflow.runtime.store.provider import POSTGRES_CONN_REQUIRED, POSTGRES_STORE_INSTALL, SQLITE_STORE_INSTALL, ensure_sqlite_parent_dir, resolve_sqlite_conn_str - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Internal backend factory -# --------------------------------------------------------------------------- - - -@contextlib.asynccontextmanager -async def _async_store(config) -> AsyncIterator[BaseStore]: - """Async context manager that constructs and tears down a Store. - - The ``config`` argument is a :class:`deerflow.config.checkpointer_config.CheckpointerConfig` - instance — the same object used by the checkpointer factory. - """ - if config.type == "memory": - from langgraph.store.memory import InMemoryStore - - logger.info("Store: using InMemoryStore (in-process, not persistent)") - yield InMemoryStore() - return - - if config.type == "sqlite": - try: - from langgraph.store.sqlite.aio import AsyncSqliteStore - except ImportError as exc: - raise ImportError(SQLITE_STORE_INSTALL) from exc - - conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") - ensure_sqlite_parent_dir(conn_str) - - async with AsyncSqliteStore.from_conn_string(conn_str) as store: - await store.setup() - logger.info("Store: using AsyncSqliteStore (%s)", conn_str) - yield store - return - - if config.type == "postgres": - try: - from langgraph.store.postgres.aio import AsyncPostgresStore # type: ignore[import] - except ImportError as exc: - raise ImportError(POSTGRES_STORE_INSTALL) from exc - - if not config.connection_string: - raise ValueError(POSTGRES_CONN_REQUIRED) - - async with AsyncPostgresStore.from_conn_string(config.connection_string) as store: - await store.setup() - logger.info("Store: using AsyncPostgresStore") - yield store - return - - raise ValueError(f"Unknown store backend type: {config.type!r}") - - -# --------------------------------------------------------------------------- -# Public async context manager -# --------------------------------------------------------------------------- - - -@contextlib.asynccontextmanager -async def make_store() -> AsyncIterator[BaseStore]: - """Async context manager that yields a Store whose backend matches the - configured checkpointer. 
- - Reads from the same ``checkpointer`` section of *config.yaml* used by - :func:`deerflow.runtime.checkpointer.async_provider.make_checkpointer` so - that both singletons always use the same persistence technology:: - - async with make_store() as store: - app.state.store = store - - Yields an :class:`~langgraph.store.memory.InMemoryStore` when no - ``checkpointer`` section is configured (emits a WARNING in that case). - """ - config = get_app_config() - - if config.checkpointer is None: - from langgraph.store.memory import InMemoryStore - - logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") - yield InMemoryStore() - return - - async with _async_store(config.checkpointer) as store: - yield store diff --git a/backend/packages/harness/deerflow/runtime/store/provider.py b/backend/packages/harness/deerflow/runtime/store/provider.py deleted file mode 100644 index a9394fb9f..000000000 --- a/backend/packages/harness/deerflow/runtime/store/provider.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Sync Store factory. - -Provides a **sync singleton** and a **sync context manager** for CLI tools -and the embedded :class:`~deerflow.client.DeerFlowClient`. - -The backend mirrors the configured checkpointer so that both always use the -same persistence technology. Supported backends: memory, sqlite, postgres. - -Usage:: - - from deerflow.runtime.store.provider import get_store, store_context - - # Singleton — reused across calls, closed on process exit - store = get_store() - - # One-shot — fresh connection, closed on block exit - with store_context() as store: - store.put(("ns",), "key", {"value": 1}) -""" - -from __future__ import annotations - -import contextlib -import logging -from collections.abc import Iterator - -from langgraph.store.base import BaseStore - -from deerflow.config.app_config import get_app_config -from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str - -logger = logging.getLogger(__name__) - -# --------------------------------------------------------------------------- -# Error message constants -# --------------------------------------------------------------------------- - -SQLITE_STORE_INSTALL = "langgraph-checkpoint-sqlite is required for the SQLite store. Install it with: uv add langgraph-checkpoint-sqlite" -POSTGRES_STORE_INSTALL = "langgraph-checkpoint-postgres is required for the PostgreSQL store. Install it with: uv add langgraph-checkpoint-postgres psycopg[binary] psycopg-pool" -POSTGRES_CONN_REQUIRED = "checkpointer.connection_string is required for the postgres backend" - -# --------------------------------------------------------------------------- -# Sync factory -# --------------------------------------------------------------------------- - - -@contextlib.contextmanager -def _sync_store_cm(config) -> Iterator[BaseStore]: - """Context manager that creates and tears down a sync Store. - - The ``config`` argument is a - :class:`~deerflow.config.checkpointer_config.CheckpointerConfig` instance — - the same object used by the checkpointer factory. 
- """ - if config.type == "memory": - from langgraph.store.memory import InMemoryStore - - logger.info("Store: using InMemoryStore (in-process, not persistent)") - yield InMemoryStore() - return - - if config.type == "sqlite": - try: - from langgraph.store.sqlite import SqliteStore - except ImportError as exc: - raise ImportError(SQLITE_STORE_INSTALL) from exc - - conn_str = resolve_sqlite_conn_str(config.connection_string or "store.db") - ensure_sqlite_parent_dir(conn_str) - - with SqliteStore.from_conn_string(conn_str) as store: - store.setup() - logger.info("Store: using SqliteStore (%s)", conn_str) - yield store - return - - if config.type == "postgres": - try: - from langgraph.store.postgres import PostgresStore # type: ignore[import] - except ImportError as exc: - raise ImportError(POSTGRES_STORE_INSTALL) from exc - - if not config.connection_string: - raise ValueError(POSTGRES_CONN_REQUIRED) - - with PostgresStore.from_conn_string(config.connection_string) as store: - store.setup() - logger.info("Store: using PostgresStore") - yield store - return - - raise ValueError(f"Unknown store backend type: {config.type!r}") - - -# --------------------------------------------------------------------------- -# Sync singleton -# --------------------------------------------------------------------------- - -_store: BaseStore | None = None -_store_ctx = None # open context manager keeping the connection alive - - -def get_store() -> BaseStore: - """Return the global sync Store singleton, creating it on first call. - - Returns an :class:`~langgraph.store.memory.InMemoryStore` when no - checkpointer is configured in *config.yaml* (emits a WARNING in that case). - - Raises: - ImportError: If the required package for the configured backend is not installed. - ValueError: If ``connection_string`` is missing for a backend that requires it. - """ - global _store, _store_ctx - - if _store is not None: - return _store - - # Lazily load app config, mirroring the checkpointer singleton pattern so - # that tests that set the global checkpointer config explicitly remain isolated. - from deerflow.config.app_config import _app_config - from deerflow.config.checkpointer_config import get_checkpointer_config - - config = get_checkpointer_config() - - if config is None and _app_config is None: - try: - get_app_config() - except FileNotFoundError: - pass - config = get_checkpointer_config() - - if config is None: - from langgraph.store.memory import InMemoryStore - - logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") - _store = InMemoryStore() - return _store - - _store_ctx = _sync_store_cm(config) - _store = _store_ctx.__enter__() - return _store - - -def reset_store() -> None: - """Reset the sync singleton, forcing recreation on the next call. - - Closes any open backend connections and clears the cached instance. - Useful in tests or after a configuration change. 
- """ - global _store, _store_ctx - if _store_ctx is not None: - try: - _store_ctx.__exit__(None, None, None) - except Exception: - logger.warning("Error during store cleanup", exc_info=True) - _store_ctx = None - _store = None - - -# --------------------------------------------------------------------------- -# Sync context manager -# --------------------------------------------------------------------------- - - -@contextlib.contextmanager -def store_context() -> Iterator[BaseStore]: - """Sync context manager that yields a Store and cleans up on exit. - - Unlike :func:`get_store`, this does **not** cache the instance — each - ``with`` block creates and destroys its own connection. Use it in CLI - scripts or tests where you want deterministic cleanup:: - - with store_context() as store: - store.put(("threads",), thread_id, {...}) - - Yields an :class:`~langgraph.store.memory.InMemoryStore` when no - checkpointer is configured in *config.yaml*. - """ - config = get_app_config() - if config.checkpointer is None: - from langgraph.store.memory import InMemoryStore - - logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") - yield InMemoryStore() - return - - with _sync_store_cm(config.checkpointer) as store: - yield store