Mirror of https://github.com/bytedance/deer-flow.git (synced 2026-04-25 11:18:22 +00:00)

feat(storage): implement unified persistence layer with database models and repositories

parent 3a99c4e81c
commit 38a6ec496f
@@ -1,17 +1,12 @@
 {
   "$schema": "https://langgra.ph/schema.json",
   "python_version": "3.12",
-  "dependencies": [
-    "."
-  ],
+  "dependencies": ["."],
   "env": ".env",
   "graphs": {
     "lead_agent": "deerflow.agents:make_lead_agent"
   },
   "auth": {
     "path": "./app/gateway/langgraph_auth.py:auth"
   },
   "checkpointer": {
-    "path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
+    "path": "./packages/harness/deerflow/runtime/checkpointer/async_provider.py:make_checkpointer"
   }
 }
35  backend/packages/storage/pyproject.toml  Normal file
@@ -0,0 +1,35 @@
[project]
name = "deerflow-storage"
version = "0.1.0"
description = "DeerFlow storage framework"
requires-python = ">=3.12"
dependencies = [
    "dotenv>=0.9.9",
    "pydantic>=2.12.5",
    "pyyaml>=6.0.3",
    "sqlalchemy[asyncio]>=2.0,<3.0",
    "alembic>=1.13",
    "langgraph>=1.0.6,<1.0.10",
]
[project.optional-dependencies]
postgres = [
    "asyncpg>=0.29",
    "langgraph-checkpoint-postgres>=3.0.5",
    "psycopg[binary]>=3.3.3",
    "psycopg-pool>=3.3.0",
]
mysql = [
    "aiomysql>=0.2",
    "langgraph-checkpoint-mysql>=3.0.0",
]
sqlite = [
    "aiosqlite>=0.22.1",
    "langgraph-checkpoint-sqlite>=3.0.3",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["store"]
0  backend/packages/storage/store/__init__.py  Normal file
5  backend/packages/storage/store/common/__init__.py  Normal file
@@ -0,0 +1,5 @@
from .enums import DataBaseType

__all__ = [
    'DataBaseType',
]
40  backend/packages/storage/store/common/enums.py  Normal file
@@ -0,0 +1,40 @@
from enum import Enum
from enum import IntEnum as SourceIntEnum
from typing import Any, TypeVar

T = TypeVar('T', bound=Enum)


class _EnumBase:
    """Base enum class with common utility methods."""

    @classmethod
    def get_member_keys(cls) -> list[str]:
        """Return a list of enum member names."""
        return list(cls.__members__.keys())

    @classmethod
    def get_member_values(cls) -> list:
        """Return a list of enum member values."""
        return [item.value for item in cls.__members__.values()]

    @classmethod
    def get_member_dict(cls) -> dict[str, Any]:
        """Return a dict mapping member names to values."""
        return {name: item.value for name, item in cls.__members__.items()}


class IntEnum(_EnumBase, SourceIntEnum):
    """Integer enum base class."""


class StrEnum(_EnumBase, str, Enum):
    """String enum base class."""


class DataBaseType(StrEnum):
    """Database type."""

    sqlite = 'sqlite'
    mysql = 'mysql'
    postgresql = 'postgresql'
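As a quick illustration (not part of the commit): the _EnumBase helpers expose member names and values, and because StrEnum members are str subclasses they compare equal to their raw string values, which is what lets config code match plain driver strings against DataBaseType. The assertions below assume the module path laid out above.

from store.common.enums import DataBaseType

assert DataBaseType.get_member_keys() == ["sqlite", "mysql", "postgresql"]
assert DataBaseType.get_member_values() == ["sqlite", "mysql", "postgresql"]
assert DataBaseType.sqlite == "sqlite"  # str subclass: a member equals its raw value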
0  backend/packages/storage/store/config/__init__.py  Normal file
244  backend/packages/storage/store/config/app_config.py  Normal file
@@ -0,0 +1,244 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self

import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field

from store.config.storage_config import StorageConfig

load_dotenv()

logger = logging.getLogger(__name__)


def _default_config_candidates() -> tuple[Path, ...]:
    """Return deterministic config.yaml locations without relying on cwd."""
    backend_dir = Path(__file__).resolve().parents[3]
    repo_root = backend_dir.parent
    return backend_dir / "config.yaml", repo_root / "config.yaml"


class AppConfig(BaseModel):
    """DeerFlow application configuration."""

    timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
    log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
    storage: StorageConfig = Field(default_factory=StorageConfig)
    model_config = ConfigDict(extra="allow", frozen=False)

    @classmethod
    def resolve_config_path(cls, config_path: str | None = None) -> Path:
        """Resolve the config file path.

        Priority:
        1. The `config_path` argument, if provided.
        2. The `DEER_FLOW_CONFIG_PATH` environment variable, if set.
        3. Otherwise, the deterministic backend/repository-root defaults from `_default_config_candidates()`.
        """
        if config_path:
            path = Path(config_path)
            if not path.exists():
                raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
            return path
        elif os.getenv("DEER_FLOW_CONFIG_PATH"):
            path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
            if not path.exists():
                raise FileNotFoundError(
                    f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
            return path
        else:
            for path in _default_config_candidates():
                if path.exists():
                    return path
            raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")

    @classmethod
    def from_file(cls, config_path: str | None = None) -> Self:
        """Load and validate config from YAML. See `resolve_config_path` for path resolution."""
        resolved_path = cls.resolve_config_path(config_path)
        with open(resolved_path, encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        cls._check_config_version(config_data, resolved_path)

        config_data = cls.resolve_env_variables(config_data)

        if os.getenv("TIMEZONE"):
            config_data["timezone"] = os.getenv("TIMEZONE")

        result = cls.model_validate(config_data)
        return result

    @classmethod
    def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
        """Check whether the user's config.yaml is outdated compared to config.example.yaml.

        Emits a warning if the user's config_version is lower than the example's.
        A missing config_version is treated as version 0 (pre-versioning).
        """
        try:
            user_version = int(config_data.get("config_version", 0))
        except (TypeError, ValueError):
            user_version = 0

        # Find config.example.yaml by searching config.yaml's directory and its parents
        example_path = None
        search_dir = config_path.parent
        for _ in range(5):  # search up to 5 levels
            candidate = search_dir / "config.example.yaml"
            if candidate.exists():
                example_path = candidate
                break
            parent = search_dir.parent
            if parent == search_dir:
                break
            search_dir = parent
        if example_path is None:
            return

        try:
            with open(example_path, encoding="utf-8") as f:
                example_data = yaml.safe_load(f)
            raw = example_data.get("config_version", 0) if example_data else 0
            try:
                example_version = int(raw)
            except (TypeError, ValueError):
                example_version = 0
        except Exception:
            return

        if user_version < example_version:
            logger.warning(
                "Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to "
                "merge new fields into your config.",
                user_version,
                example_version,
            )

    @classmethod
    def resolve_env_variables(cls, config: Any) -> Any:
        """Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY)."""
        if isinstance(config, str):
            if config.startswith("$"):
                env_value = os.getenv(config[1:])
                if env_value is None:
                    raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
                return env_value
            return config
        elif isinstance(config, dict):
            return {k: cls.resolve_env_variables(v) for k, v in config.items()}
        elif isinstance(config, list):
            return [cls.resolve_env_variables(item) for item in config]
        return config


_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack",
                                                                                 default=())


def _get_config_mtime(config_path: Path) -> float | None:
    """Get the modification time of a config file if it exists."""
    try:
        return config_path.stat().st_mtime
    except OSError:
        return None


def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
    """Load config from disk and refresh cache metadata."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom

    resolved_path = AppConfig.resolve_config_path(config_path)
    _app_config = AppConfig.from_file(str(resolved_path))
    _app_config_path = resolved_path
    _app_config_mtime = _get_config_mtime(resolved_path)
    _app_config_is_custom = False
    return _app_config


def get_app_config() -> AppConfig:
    """Get the DeerFlow config instance.

    Returns a cached singleton instance and automatically reloads it when the
    underlying config file path or modification time changes. Use
    `reload_app_config()` to force a reload, or `reset_app_config()` to clear
    the cache.
    """
    global _app_config, _app_config_path, _app_config_mtime

    runtime_override = _current_app_config.get()
    if runtime_override is not None:
        return runtime_override

    if _app_config is not None and _app_config_is_custom:
        return _app_config

    resolved_path = AppConfig.resolve_config_path()
    current_mtime = _get_config_mtime(resolved_path)

    should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
    if should_reload:
        if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
            logger.info(
                "Config file has been modified (mtime: %s -> %s), reloading AppConfig",
                _app_config_mtime,
                current_mtime,
            )
        _load_and_cache_app_config(str(resolved_path))
    return _app_config


def reload_app_config(config_path: str | None = None) -> AppConfig:
    """Force reload from file and update the cache."""
    return _load_and_cache_app_config(config_path)


def reset_app_config() -> None:
    """Clear the cache so the next `get_app_config()` reloads from file."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config = None
    _app_config_path = None
    _app_config_mtime = None
    _app_config_is_custom = False


def set_app_config(config: AppConfig) -> None:
    """Inject a config instance directly, bypassing file loading (for testing)."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config = config
    _app_config_path = None
    _app_config_mtime = None
    _app_config_is_custom = True


def peek_current_app_config() -> AppConfig | None:
    """Return the runtime-scoped AppConfig override, if one is active."""
    return _current_app_config.get()


def push_current_app_config(config: AppConfig) -> None:
    """Push a runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    _current_app_config_stack.set(stack + (_current_app_config.get(),))
    _current_app_config.set(config)


def pop_current_app_config() -> None:
    """Pop the latest runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    if not stack:
        _current_app_config.set(None)
        return
    previous = stack[-1]
    _current_app_config_stack.set(stack[:-1])
    _current_app_config.set(previous)
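A minimal usage sketch (not part of the commit) showing the cached singleton and the ContextVar override stack; the override's field values are illustrative and rely on a valid config.yaml being present.

from store.config.app_config import (
    AppConfig,
    get_app_config,
    pop_current_app_config,
    push_current_app_config,
)

config = get_app_config()               # loads config.yaml once, reloads when its mtime changes
print(config.timezone, config.storage.driver)

override = AppConfig(timezone="America/New_York")
push_current_app_config(override)       # scoped to the current execution context
try:
    assert get_app_config() is override
finally:
    pop_current_app_config()            # restores the previous override or the file-backed config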
52  backend/packages/storage/store/config/storage_config.py  Normal file
@@ -0,0 +1,52 @@
"""Unified storage backend configuration for checkpointer and application data.

SQLite: checkpointer → {sqlite_dir}/checkpoints.db, app → {sqlite_dir}/deerflow.db
    (separate files to avoid write-lock contention)
Postgres: shared URL, independent connection pools per layer.

Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
before this config is instantiated.
"""

from __future__ import annotations

import os
from typing import Literal

from pydantic import BaseModel, Field


class StorageConfig(BaseModel):
    driver: Literal["mysql", "sqlite", "postgresql"] = Field(
        default="sqlite",
        description="Storage driver for both checkpointer and application data. "
                    "'sqlite' for single-node deployment (default), "
                    "'postgresql' for production multi-node deployment, "
                    "'mysql' for MySQL databases.",
    )
    sqlite_dir: str = Field(
        default=".deer-flow/data",
        description="Directory for SQLite .db files (sqlite driver only).",
    )
    username: str = Field(default="", description="db username.")
    password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
    host: str = Field(default="localhost", description="db host.")
    port: int = Field(default=5432, description="db port.")
    db_name: str = Field(default="deerflow", description="db database name.")
    echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
    pool_size: int = Field(default=5, description="Connection pool size per layer.")

    # -- Derived helpers (not user-configured) --

    @property
    def _resolved_sqlite_dir(self) -> str:
        """Resolve sqlite_dir to an absolute path (relative to CWD)."""
        from pathlib import Path

        return str(Path(self.sqlite_dir).resolve())

    @property
    def sqlite_storage_path(self) -> str:
        """SQLite file path for the LangGraph checkpointer."""
        return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
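For illustration only (not part of the commit): StorageConfig is normally populated from the storage block of config.yaml after AppConfig.resolve_env_variables() has replaced $VAR strings. The DB_PASSWORD variable and host value below are invented for the sketch.

import os

from store.config.app_config import AppConfig
from store.config.storage_config import StorageConfig

os.environ["DB_PASSWORD"] = "s3cret"  # hypothetical variable, set here only for the sketch
raw = {"driver": "postgresql", "host": "db.internal", "password": "$DB_PASSWORD"}
cfg = StorageConfig(**AppConfig.resolve_env_variables(raw))
assert cfg.password == "s3cret"
assert cfg.port == 5432  # defaults still apply for fields the YAML omits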
23  backend/packages/storage/store/persistence/__init__.py  Normal file
@@ -0,0 +1,23 @@
from store.persistence.base_model import (
    Base,
    DataClassBase,
    DateTimeMixin,
    MappedBase,
    TimeZone,
    UniversalText,
    id_key,
)
from .factory import create_persistence
from .types import AppPersistence

__all__ = [
    "Base",
    "DataClassBase",
    "DateTimeMixin",
    "MappedBase",
    "TimeZone",
    "UniversalText",
    "id_key",
    "create_persistence",
    "AppPersistence",
]
107  backend/packages/storage/store/persistence/base_model.py  Normal file
@@ -0,0 +1,107 @@
from datetime import datetime
from typing import Annotated

from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column

from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.utils import get_timezone

timezone = get_timezone()
app_config = get_app_config()

# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger

id_key = Annotated[
    int,
    mapped_column(
        _id_type,
        primary_key=True,
        unique=True,
        index=True,
        autoincrement=True,
        sort_order=-999,
        comment="Primary key ID",
    ),
]


class UniversalText(TypeDecorator[str]):
    """Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""

    impl = LONGTEXT if app_config.storage.driver == DataBaseType.mysql else Text
    cache_ok = True

    def process_bind_param(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        return value

    def process_result_value(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        return value


class TimeZone(TypeDecorator[datetime]):
    """Timezone-aware datetime type compatible with PostgreSQL and MySQL."""

    impl = DateTime(timezone=True)
    cache_ok = True

    @property
    def python_type(self) -> type[datetime]:
        return datetime

    def process_bind_param(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        if value is not None and value.utcoffset() != timezone.now().utcoffset():
            value = timezone.from_datetime(value)
        return value

    def process_result_value(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        if value is not None and value.tzinfo is None:
            value = value.replace(tzinfo=timezone.tz_info)
        return value


class DateTimeMixin(MappedAsDataclass):
    """Mixin that adds created_time / updated_time columns."""

    created_time: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=timezone.now,
        sort_order=999,
        comment="Created at",
    )
    updated_time: Mapped[datetime | None] = mapped_column(
        TimeZone,
        init=False,
        onupdate=timezone.now,
        sort_order=999,
        comment="Updated at",
    )


class MappedBase(AsyncAttrs, DeclarativeBase):
    """Async-capable declarative base for all ORM models."""

    @declared_attr.directive
    def __tablename__(cls) -> str:
        return cls.__name__.lower()

    @declared_attr.directive
    def __table_args__(cls) -> dict:
        return {"comment": cls.__doc__ or ""}


class DataClassBase(MappedAsDataclass, MappedBase):
    """Declarative base with native dataclass integration."""

    __abstract__ = True


class Base(DataClassBase, DateTimeMixin):
    """Declarative dataclass base with created_time / updated_time columns."""

    __abstract__ = True
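A sketch (not in the commit) of how a concrete table would be declared on top of these bases; the Note model and its columns are invented for illustration. The local mapped_column(init=False) merges with the one carried by the id_key annotation, so the autoincrement key stays out of the dataclass constructor.

from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import Base, UniversalText, id_key


class Note(Base):
    """Example business table (hypothetical)."""

    id: Mapped[id_key] = mapped_column(init=False)        # Integer on SQLite, BigInteger elsewhere
    body: Mapped[str] = mapped_column(UniversalText, default="")  # LONGTEXT on MySQL, Text otherwise
    # created_time / updated_time arrive via DateTimeMixin through Base;
    # __tablename__ defaults to "note" and the table comment to this docstring.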
@@ -0,0 +1,9 @@
from .mysql import build_mysql_persistence
from .postgres import build_postgres_persistence
from .sqlite import build_sqlite_persistence

__all__ = [
    "build_postgres_persistence",
    "build_mysql_persistence",
    "build_sqlite_persistence",
]
67  backend/packages/storage/store/persistence/drivers/mysql.py  Normal file
@@ -0,0 +1,67 @@
from __future__ import annotations

from sqlalchemy import URL
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


def _validate_mysql_driver(db_url: URL | str) -> str:
    url = make_url(db_url)
    driver = url.get_driver_name()

    if driver not in {"aiomysql", "asyncmy"}:
        raise ValueError(
            "MySQL persistence requires an async SQLAlchemy driver "
            f"(aiomysql/asyncmy), got: {driver!r}"
        )
    return driver


async def build_mysql_persistence(db_url: URL) -> AppPersistence:
    _validate_mysql_driver(db_url)

    from langgraph.checkpoint.mysql.aio import AIOMySQLSaver

    engine = create_async_engine(
        db_url,
        future=True,
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AIOMySQLSaver.from_conn_string(db_url)
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
@@ -0,0 +1,52 @@
from __future__ import annotations

from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


async def build_postgres_persistence(db_url: URL) -> AppPersistence:
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

    engine = create_async_engine(
        db_url,
        future=True,
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AsyncPostgresSaver.from_conn_string(db_url)
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
53  backend/packages/storage/store/persistence/drivers/sqlite.py  Normal file
@@ -0,0 +1,53 @@
from __future__ import annotations

from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


async def build_sqlite_persistence(db_url: URL) -> AppPersistence:
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    engine = create_async_engine(
        db_url,
        future=True,
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
63  backend/packages/storage/store/persistence/factory.py  Normal file
@@ -0,0 +1,63 @@
from sqlalchemy import URL

from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.config.storage_config import StorageConfig
from store.persistence.types import AppPersistence


def _create_database_url(storage_config: StorageConfig) -> URL:
    """Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgresql)."""

    if storage_config.driver == DataBaseType.sqlite:
        driver = "sqlite+aiosqlite"
    elif storage_config.driver == DataBaseType.mysql:
        driver = "mysql+aiomysql"
    elif storage_config.driver == DataBaseType.postgresql:
        driver = "postgresql+asyncpg"
    else:
        raise ValueError(f"Unsupported database driver: {storage_config.driver}")

    if storage_config.driver == DataBaseType.sqlite:
        import os

        db_path = storage_config.sqlite_storage_path
        os.makedirs(os.path.dirname(db_path), exist_ok=True)

        url = URL.create(
            drivername=driver,
            database=db_path,
        )
    else:
        url = URL.create(
            drivername=driver,
            username=storage_config.username,
            password=storage_config.password,
            host=storage_config.host,
            port=storage_config.port,
            database=storage_config.db_name or "deerflow",
        )

    return url


async def create_persistence() -> AppPersistence:
    from .drivers.mysql import build_mysql_persistence
    from .drivers.postgres import build_postgres_persistence
    from .drivers.sqlite import build_sqlite_persistence

    app_config = get_app_config()

    driver = app_config.storage.driver
    db_url = _create_database_url(app_config.storage)

    if driver == "postgresql":
        return await build_postgres_persistence(db_url)

    if driver == "mysql":
        return await build_mysql_persistence(db_url)

    if driver == "sqlite":
        return await build_sqlite_persistence(db_url)

    raise ValueError(f"Unsupported database driver: {driver}")
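An end-to-end lifecycle sketch (not part of the commit); it assumes a valid config.yaml is resolvable and shows the intended order: create, setup, use sessions, close.

import asyncio

from store.persistence import create_persistence


async def main() -> None:
    persistence = await create_persistence()  # driver picked from AppConfig.storage
    await persistence.setup()                 # 1) checkpointer tables, 2) ORM metadata.create_all
    try:
        async with persistence.session_factory() as session:
            ...                               # repository work against `session` goes here
    finally:
        await persistence.aclose()            # engine.dispose first, then the saver's __aexit__


asyncio.run(main())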
@@ -0,0 +1,3 @@
from .close import close_in_order

__all__ = ["close_in_order"]
29  backend/packages/storage/store/persistence/shared/close.py  Normal file
@@ -0,0 +1,29 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable

AsyncCloser = Callable[[], Awaitable[None]]


async def close_in_order(*closers: AsyncCloser) -> None:
    """
    Run async closers in order and raise the first error, if any.

    Notes
    -----
    - Used to keep driver-specific close logic readable.
    - We intentionally do not stop at the first failure, so later resources
      still get a chance to close.
    """
    first_error: Exception | None = None

    for closer in closers:
        try:
            await closer()
        except Exception as exc:
            if first_error is None:
                first_error = exc

    if first_error is not None:
        raise first_error
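A behaviour sketch (not in the commit): every closer runs even when an earlier one fails, and the first exception is re-raised afterwards. The closer names are invented.

import asyncio

from store.persistence.shared import close_in_order

calls: list[str] = []


async def flaky() -> None:
    calls.append("flaky")
    raise RuntimeError("close failed")


async def steady() -> None:
    calls.append("steady")


try:
    asyncio.run(close_in_order(flaky, steady))
except RuntimeError:
    pass

assert calls == ["flaky", "steady"]  # the later closer still ran; the first error was re-raised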
23  backend/packages/storage/store/persistence/types.py  Normal file
@@ -0,0 +1,23 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Awaitable, Callable

from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

AsyncSetup = Callable[[], Awaitable[None]]
AsyncClose = Callable[[], Awaitable[None]]


@dataclass(slots=True)
class AppPersistence:
    """Unified runtime persistence bundle."""

    checkpointer: Checkpointer
    engine: AsyncEngine
    session_factory: async_sessionmaker[AsyncSession]
    setup: AsyncSetup
    aclose: AsyncClose
40  backend/packages/storage/store/repositories/__init__.py  Normal file
@@ -0,0 +1,40 @@
from store.repositories.contracts import (
    Feedback,
    FeedbackCreate,
    FeedbackRepositoryProtocol,
    Run,
    RunCreate,
    RunEvent,
    RunEventCreate,
    RunEventRepositoryProtocol,
    RunRepositoryProtocol,
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
)
from store.repositories.factory import (
    build_feedback_repository,
    build_run_event_repository,
    build_run_repository,
    build_thread_meta_repository,
)

__all__ = [
    "Feedback",
    "FeedbackCreate",
    "FeedbackRepositoryProtocol",
    "Run",
    "RunCreate",
    "RunEvent",
    "RunEventCreate",
    "RunEventRepositoryProtocol",
    "RunRepositoryProtocol",
    "ThreadMeta",
    "ThreadMetaCreate",
    "ThreadMetaRepositoryProtocol",
    "build_run_repository",
    "build_run_event_repository",
    "build_thread_meta_repository",
    "build_feedback_repository",
]
@@ -0,0 +1,35 @@
from store.repositories.contracts.feedback import (
    Feedback,
    FeedbackCreate,
    FeedbackRepositoryProtocol,
)
from store.repositories.contracts.run import (
    Run,
    RunCreate,
    RunRepositoryProtocol,
)
from store.repositories.contracts.run_event import (
    RunEvent,
    RunEventCreate,
    RunEventRepositoryProtocol,
)
from store.repositories.contracts.thread_meta import (
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
)

__all__ = [
    "Feedback",
    "FeedbackCreate",
    "FeedbackRepositoryProtocol",
    "Run",
    "RunCreate",
    "RunEvent",
    "RunEventCreate",
    "RunEventRepositoryProtocol",
    "RunRepositoryProtocol",
    "ThreadMeta",
    "ThreadMetaCreate",
    "ThreadMetaRepositoryProtocol",
]
||||
@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class FeedbackCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
rating: int
|
||||
user_id: str | None = None
|
||||
message_id: str | None = None
|
||||
comment: str | None = None
|
||||
|
||||
|
||||
class Feedback(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
rating: int
|
||||
user_id: str | None
|
||||
message_id: str | None
|
||||
comment: str | None
|
||||
created_time: datetime
|
||||
|
||||
|
||||
class FeedbackRepositoryProtocol(Protocol):
|
||||
async def create_feedback(self, data: FeedbackCreate) -> Feedback: ...
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None: ...
|
||||
async def list_feedback_by_run(self, run_id: str) -> list[Feedback]: ...
|
||||
async def list_feedback_by_thread(self, thread_id: str) -> list[Feedback]: ...
|
||||
async def delete_feedback(self, feedback_id: str) -> bool: ...
|
||||
74  backend/packages/storage/store/repositories/contracts/run.py  Normal file
@@ -0,0 +1,74 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class RunCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    run_id: str
    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    status: str = "pending"
    model_name: str | None = None
    multitask_strategy: str = "reject"
    follow_up_to_run_id: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    kwargs: dict[str, Any] = Field(default_factory=dict)


class Run(BaseModel):
    model_config = ConfigDict(frozen=True)

    run_id: str
    thread_id: str
    assistant_id: str | None
    user_id: str | None
    status: str
    model_name: str | None
    multitask_strategy: str
    error: str | None
    follow_up_to_run_id: str | None
    metadata: dict[str, Any]
    kwargs: dict[str, Any]
    total_input_tokens: int
    total_output_tokens: int
    total_tokens: int
    llm_call_count: int
    lead_agent_tokens: int
    subagent_tokens: int
    middleware_tokens: int
    message_count: int
    first_human_message: str | None
    last_ai_message: str | None
    created_time: datetime
    updated_time: datetime | None


class RunRepositoryProtocol(Protocol):
    async def create_run(self, data: RunCreate) -> Run: ...
    async def get_run(self, run_id: str) -> Run | None: ...
    async def list_runs_by_thread(self, thread_id: str, *, limit: int = 50, offset: int = 0) -> list[Run]: ...
    async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None: ...
    async def delete_run(self, run_id: str) -> None: ...
    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None: ...
@@ -0,0 +1,69 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class RunEventCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    thread_id: str
    run_id: str
    event_type: str
    category: str
    content: str | dict[str, Any] = ""
    metadata: dict[str, Any] = Field(default_factory=dict)
    created_at: datetime | None = None


class RunEvent(BaseModel):
    model_config = ConfigDict(frozen=True)

    thread_id: str
    run_id: str
    event_type: str
    category: str
    content: str | dict[str, Any]
    metadata: dict[str, Any]
    seq: int
    created_at: datetime


class RunEventRepositoryProtocol(Protocol):
    async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[RunEvent]: ...

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
    ) -> list[RunEvent]: ...

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[RunEvent]: ...

    async def count_messages(self, thread_id: str) -> int: ...

    async def delete_by_thread(self, thread_id: str) -> int: ...

    async def delete_by_run(self, thread_id: str, run_id: str) -> int: ...
@@ -0,0 +1,58 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class ThreadMetaCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    display_name: str | None = None
    status: str = "idle"
    metadata: dict[str, Any] = Field(default_factory=dict)


class ThreadMeta(BaseModel):
    model_config = ConfigDict(frozen=True)

    thread_id: str
    assistant_id: str | None
    user_id: str | None
    display_name: str | None
    status: str
    metadata: dict[str, Any]
    created_time: datetime
    updated_time: datetime | None


class ThreadMetaRepositoryProtocol(Protocol):
    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta: ...

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None: ...

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None: ...

    async def delete_thread(self, thread_id: str) -> None: ...

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]: ...
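Because the contracts are typing.Protocol classes, call sites can depend on them without importing SQLAlchemy, and a test double only needs structurally matching methods. A minimal in-memory fake (not part of the commit, two methods shown for brevity):

from datetime import datetime, timezone

from store.repositories.contracts import Feedback, FeedbackCreate


class InMemoryFeedbackRepository:
    """Structural match for FeedbackRepositoryProtocol (illustrative test double)."""

    def __init__(self) -> None:
        self._items: dict[str, Feedback] = {}

    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        fb = Feedback(**data.model_dump(), created_time=datetime.now(timezone.utc))
        self._items[fb.feedback_id] = fb
        return fb

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        return self._items.get(feedback_id)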
11  backend/packages/storage/store/repositories/db/__init__.py  Normal file
@@ -0,0 +1,11 @@
from store.repositories.db.feedback import DbFeedbackRepository
from store.repositories.db.run import DbRunRepository
from store.repositories.db.run_event import DbRunEventRepository
from store.repositories.db.thread_meta import DbThreadMetaRepository

__all__ = [
    "DbFeedbackRepository",
    "DbRunRepository",
    "DbRunEventRepository",
    "DbThreadMetaRepository",
]
74  backend/packages/storage/store/repositories/db/feedback.py  Normal file
@@ -0,0 +1,74 @@
from __future__ import annotations

from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.feedback import Feedback, FeedbackCreate, FeedbackRepositoryProtocol
from store.repositories.models.feedback import Feedback as FeedbackModel


def _to_feedback(m: FeedbackModel) -> Feedback:
    return Feedback(
        feedback_id=m.feedback_id,
        run_id=m.run_id,
        thread_id=m.thread_id,
        rating=m.rating,
        user_id=m.user_id,
        message_id=m.message_id,
        comment=m.comment,
        created_time=m.created_time,
    )


class DbFeedbackRepository(FeedbackRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        if data.rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {data.rating}")
        model = FeedbackModel(
            feedback_id=data.feedback_id,
            run_id=data.run_id,
            thread_id=data.thread_id,
            rating=data.rating,
            user_id=data.user_id,
            message_id=data.message_id,
            comment=data.comment,
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_feedback(model)

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        result = await self._session.execute(
            select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
        )
        model = result.scalar_one_or_none()
        return _to_feedback(model) if model else None

    async def list_feedback_by_run(self, run_id: str) -> list[Feedback]:
        result = await self._session.execute(
            select(FeedbackModel)
            .where(FeedbackModel.run_id == run_id)
            .order_by(FeedbackModel.created_time.desc())
        )
        return [_to_feedback(m) for m in result.scalars().all()]

    async def list_feedback_by_thread(self, thread_id: str) -> list[Feedback]:
        result = await self._session.execute(
            select(FeedbackModel)
            .where(FeedbackModel.thread_id == thread_id)
            .order_by(FeedbackModel.created_time.desc())
        )
        return [_to_feedback(m) for m in result.scalars().all()]

    async def delete_feedback(self, feedback_id: str) -> bool:
        existing = await self.get_feedback(feedback_id)
        if existing is None:
            return False
        await self._session.execute(
            delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
        )
        return True
125  backend/packages/storage/store/repositories/db/run.py  Normal file
@@ -0,0 +1,125 @@
from __future__ import annotations

from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
from store.repositories.models.run import Run as RunModel


def _to_run(m: RunModel) -> Run:
    return Run(
        run_id=m.run_id,
        thread_id=m.thread_id,
        assistant_id=m.assistant_id,
        user_id=m.user_id,
        status=m.status,
        model_name=m.model_name,
        multitask_strategy=m.multitask_strategy,
        error=m.error,
        follow_up_to_run_id=m.follow_up_to_run_id,
        metadata=dict(m.meta or {}),
        kwargs=dict(m.kwargs or {}),
        total_input_tokens=m.total_input_tokens,
        total_output_tokens=m.total_output_tokens,
        total_tokens=m.total_tokens,
        llm_call_count=m.llm_call_count,
        lead_agent_tokens=m.lead_agent_tokens,
        subagent_tokens=m.subagent_tokens,
        middleware_tokens=m.middleware_tokens,
        message_count=m.message_count,
        first_human_message=m.first_human_message,
        last_ai_message=m.last_ai_message,
        created_time=m.created_time,
        updated_time=m.updated_time,
    )


class DbRunRepository(RunRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_run(self, data: RunCreate) -> Run:
        model = RunModel(
            run_id=data.run_id,
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            status=data.status,
            model_name=data.model_name,
            multitask_strategy=data.multitask_strategy,
            follow_up_to_run_id=data.follow_up_to_run_id,
            meta=dict(data.metadata),
            kwargs=dict(data.kwargs),
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_run(model)

    async def get_run(self, run_id: str) -> Run | None:
        result = await self._session.execute(
            select(RunModel).where(RunModel.run_id == run_id)
        )
        model = result.scalar_one_or_none()
        return _to_run(model) if model else None

    async def list_runs_by_thread(
        self, thread_id: str, *, limit: int = 50, offset: int = 0
    ) -> list[Run]:
        result = await self._session.execute(
            select(RunModel)
            .where(RunModel.thread_id == thread_id)
            .order_by(RunModel.created_time.desc())
            .limit(limit)
            .offset(offset)
        )
        return [_to_run(m) for m in result.scalars().all()]

    async def update_run_status(
        self, run_id: str, status: str, *, error: str | None = None
    ) -> None:
        values: dict = {"status": status}
        if error is not None:
            values["error"] = error
        await self._session.execute(
            update(RunModel).where(RunModel.run_id == run_id).values(**values)
        )

    async def delete_run(self, run_id: str) -> None:
        await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None:
        values = {
            "status": status,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "llm_call_count": llm_call_count,
            "lead_agent_tokens": lead_agent_tokens,
            "subagent_tokens": subagent_tokens,
            "middleware_tokens": middleware_tokens,
            "message_count": message_count,
            "first_human_message": first_human_message,
            "last_ai_message": last_ai_message,
            "error": error,
        }
        await self._session.execute(
            update(RunModel).where(RunModel.run_id == run_id).values(**values)
        )
180  backend/packages/storage/store/repositories/db/run_event.py  Normal file
@@ -0,0 +1,180 @@
from __future__ import annotations

import json
from typing import Any

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel


def _serialize_content(content: str | dict[str, Any], metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    if isinstance(content, dict):
        return json.dumps(content, default=str, ensure_ascii=False), {**metadata, "content_is_dict": True}
    return content, metadata


def _deserialize_content(content: str, metadata: dict[str, Any]) -> str | dict[str, Any]:
    if not metadata.get("content_is_dict"):
        return content
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return content


def _to_run_event(model: RunEventModel) -> RunEvent:
    raw_metadata = dict(model.meta or {})
    metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
    return RunEvent(
        thread_id=model.thread_id,
        run_id=model.run_id,
        event_type=model.event_type,
        category=model.category,
        content=_deserialize_content(model.content, raw_metadata),
        metadata=metadata,
        seq=model.seq,
        created_at=model.created_at,
    )


class DbRunEventRepository(RunEventRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
        if not events:
            return []

        thread_ids = {event.thread_id for event in events}
        seq_by_thread: dict[str, int] = {}
        for thread_id in thread_ids:
            max_seq = await self._session.scalar(
                select(func.max(RunEventModel.seq))
                .where(RunEventModel.thread_id == thread_id)
                .with_for_update()
            )
            seq_by_thread[thread_id] = max_seq or 0

        rows: list[RunEventModel] = []

        for event in events:
            seq_by_thread[event.thread_id] += 1
            content, metadata = _serialize_content(event.content, dict(event.metadata))
            row = RunEventModel(
                thread_id=event.thread_id,
                run_id=event.run_id,
                seq=seq_by_thread[event.thread_id],
                event_type=event.event_type,
                category=event.category,
                content=content,
                meta=metadata,
            )
            if event.created_at is not None:
                row.created_at = event.created_at
            self._session.add(row)
            rows.append(row)

        await self._session.flush()
        return [_to_run_event(row) for row in rows]

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.category == "message",
        )
        if before_seq is not None:
            stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
            result = await self._session.execute(stmt)
            return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
        if after_seq is not None:
            stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
            result = await self._session.execute(stmt)
            return [_to_run_event(row) for row in result.scalars().all()]

        stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
        result = await self._session.execute(stmt)
        return list(reversed([_to_run_event(row) for row in result.scalars().all()]))

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.run_id == run_id,
        )
        if event_types is not None:
            stmt = stmt.where(RunEventModel.event_type.in_(event_types))
        stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
        result = await self._session.execute(stmt)
        return [_to_run_event(row) for row in result.scalars().all()]

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
    ) -> list[RunEvent]:
        stmt = (
            select(RunEventModel)
            .where(
                RunEventModel.thread_id == thread_id,
                RunEventModel.run_id == run_id,
                RunEventModel.category == "message",
            )
        )
        if before_seq is not None:
            stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
            result = await self._session.execute(stmt)
            return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
        if after_seq is not None:
            stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
            result = await self._session.execute(stmt)
            return [_to_run_event(row) for row in result.scalars().all()]

        stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
        result = await self._session.execute(stmt)
        return list(reversed([_to_run_event(row) for row in result.scalars().all()]))

    async def count_messages(self, thread_id: str) -> int:
        count = await self._session.scalar(
            select(func.count())
            .select_from(RunEventModel)
            .where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
        )
        return int(count or 0)

    async def delete_by_thread(self, thread_id: str) -> int:
        count = await self._session.scalar(
            select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id)
        )
        await self._session.execute(delete(RunEventModel).where(RunEventModel.thread_id == thread_id))
        return int(count or 0)

    async def delete_by_run(self, thread_id: str, run_id: str) -> int:
        count = await self._session.scalar(
            select(func.count())
            .select_from(RunEventModel)
            .where(RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id)
        )
        await self._session.execute(
            delete(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id)
        )
        return int(count or 0)
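A sequencing sketch (not in the commit): seq is allocated per thread from the stored maximum, so a batch in one thread gets strictly increasing numbers, and dict content round-trips through JSON transparently. The IDs and event_type/category strings below are invented; "message" matters because list_messages filters on that category.

from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories import RunEventCreate, build_run_event_repository


async def record_events(session: AsyncSession) -> None:
    repo = build_run_event_repository(session)
    stored = await repo.append_batch([
        RunEventCreate(thread_id="t1", run_id="r1", event_type="ai_message",
                       category="message", content="hi"),
        RunEventCreate(thread_id="t1", run_id="r1", event_type="tool_call",
                       category="event", content={"name": "search"}),  # dict content is JSON-serialized
    ])
    # seq continues from the thread's stored maximum, e.g. [n + 1, n + 2]
    print([event.seq for event in stored])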
@ -0,0 +1,97 @@
from __future__ import annotations

from typing import Any

from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel


def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
    return ThreadMeta(
        thread_id=m.thread_id,
        assistant_id=m.assistant_id,
        user_id=m.user_id,
        display_name=m.display_name,
        status=m.status,
        metadata=dict(m.meta or {}),
        created_time=m.created_time,
        updated_time=m.updated_time,
    )


class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        model = ThreadMetaModel(
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            display_name=data.display_name,
            status=data.status,
            meta=dict(data.metadata),
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_thread_meta(model)

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
        model = result.scalar_one_or_none()
        return _to_thread_meta(model) if model else None

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        values: dict = {}
        if display_name is not None:
            values["display_name"] = display_name
        if status is not None:
            values["status"] = status
        if metadata is not None:
            values["meta"] = dict(metadata)
        if not values:
            return
        await self._session.execute(
            update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values)
        )

    async def delete_thread(self, thread_id: str) -> None:
        await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        stmt = select(ThreadMetaModel)

        if status is not None:
            stmt = stmt.where(ThreadMetaModel.status == status)
        if user_id is not None:
            stmt = stmt.where(ThreadMetaModel.user_id == user_id)
        if assistant_id is not None:
            stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
        if metadata:
            for key, value in metadata.items():
                stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))

        stmt = stmt.order_by(ThreadMetaModel.created_time.desc())
        stmt = stmt.limit(limit).offset(offset)

        result = await self._session.execute(stmt)
        return [_to_thread_meta(m) for m in result.scalars().all()]
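Note that search_threads compares JSON metadata via meta[key].as_string() == str(value), so every filter value is matched in its string form. A usage sketch; the session variable and filter values are illustrative assumptions:

# Hypothetical caller; filter values are assumptions, not part of this diff.
repo = DbThreadMetaRepository(session)
threads = await repo.search_threads(
    user_id="u-123",
    status="idle",
    metadata={"project": "deerflow"},  # each pair becomes meta[key] == str(value)
    limit=20,
    offset=0,
)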
25
backend/packages/storage/store/repositories/factory.py
Normal file
@ -0,0 +1,25 @@
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories import (
    FeedbackRepositoryProtocol,
    RunEventRepositoryProtocol,
    RunRepositoryProtocol,
    ThreadMetaRepositoryProtocol,
)
from store.repositories.db import DbFeedbackRepository, DbRunEventRepository, DbRunRepository, DbThreadMetaRepository


def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
    return DbThreadMetaRepository(session)


def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
    return DbRunRepository(session)


def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
    return DbFeedbackRepository(session)


def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
    return DbRunEventRepository(session)
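These builders keep callers coded against the *Protocol interfaces rather than the Db* classes. A wiring sketch under assumed infrastructure; the engine URL and session factory below are illustrative, not part of this commit:

# Minimal wiring sketch, assuming a SQLite database file and standard SQLAlchemy async setup.
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

engine = create_async_engine("sqlite+aiosqlite:///deerflow.db")  # assumed URL
Session = async_sessionmaker(engine, expire_on_commit=False)

async def handle_request() -> None:
    async with Session() as session, session.begin():
        threads = build_thread_meta_repository(session)
        runs = build_run_repository(session)
        # Repositories share one session, so their writes commit or roll back together.
        ...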
@ -0,0 +1,6 @@
from store.repositories.models.feedback import Feedback
from store.repositories.models.run import Run
from store.repositories.models.run_event import RunEvent
from store.repositories.models.thread_meta import ThreadMeta

__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta"]
@ -0,0 +1,36 @@
from __future__ import annotations

from datetime import datetime

from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
from store.utils import get_timezone

_tz = get_timezone()


class Feedback(DataClassBase):
    """Feedback table (create-only, no updated_time)."""

    __tablename__ = "feedback"

    id: Mapped[id_key] = mapped_column(init=False)

    feedback_id: Mapped[str] = mapped_column(String(64), unique=True, index=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    rating: Mapped[int] = mapped_column(Integer)

    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    message_id: Mapped[str | None] = mapped_column(String(64), default=None)
    comment: Mapped[str | None] = mapped_column(UniversalText, default=None)

    created_time: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=_tz.now,
        sort_order=999,
        comment="Created at",
    )
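Because the model is dataclass-mapped (the init=False fields and default_factory imply a MappedAsDataclass-style base), rows are built with keyword arguments and id/created_time fill in automatically. A construction sketch; the ID scheme and rating scale are assumptions:

# Construction sketch; uuid-based IDs and a +1/-1 rating scale are assumptions, not part of this diff.
import uuid

fb = Feedback(
    feedback_id=str(uuid.uuid4()),  # assumed ID scheme
    run_id=run_id,
    thread_id=thread_id,
    rating=1,                       # assumed scale, e.g. +1 / -1
    user_id="u-123",
    comment="Helpful answer",
)
session.add(fb)  # created_time defaults to _tz.now() at construction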
42
backend/packages/storage/store/repositories/models/run.py
Normal file
@ -0,0 +1,42 @@
from __future__ import annotations

from typing import Any

from sqlalchemy import JSON, Integer, String
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import Base, UniversalText, id_key


class Run(Base):
    """Run metadata table."""

    __tablename__ = "runs"

    id: Mapped[id_key] = mapped_column(init=False)

    run_id: Mapped[str] = mapped_column(String(64), unique=True, index=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)

    assistant_id: Mapped[str | None] = mapped_column(String(64), default=None)
    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    status: Mapped[str] = mapped_column(String(32), default="pending", index=True)
    model_name: Mapped[str | None] = mapped_column(String(128), default=None)
    multitask_strategy: Mapped[str] = mapped_column(String(32), default="reject")
    error: Mapped[str | None] = mapped_column(UniversalText, default=None)
    follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)

    meta: Mapped[dict[str, Any]] = mapped_column("metadata", JSON, default_factory=dict)
    kwargs: Mapped[dict[str, Any]] = mapped_column(JSON, default_factory=dict)

    total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_tokens: Mapped[int] = mapped_column(Integer, default=0)
    llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
    lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)

    message_count: Mapped[int] = mapped_column(Integer, default=0)
    first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
    last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
@ -0,0 +1,40 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import JSON, Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
from store.utils import get_timezone

_tz = get_timezone()


class RunEvent(DataClassBase):
    """Run event table."""

    __tablename__ = "run_events"
    __table_args__ = (
        UniqueConstraint("thread_id", "seq", name="uq_run_events_thread_seq"),
        {"comment": "Run event table."},
    )

    id: Mapped[id_key] = mapped_column(init=False)

    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    seq: Mapped[int] = mapped_column(Integer, index=True)

    event_type: Mapped[str] = mapped_column(String(128), index=True)
    category: Mapped[str] = mapped_column(String(64), index=True)
    content: Mapped[str] = mapped_column(UniversalText, default="")
    meta: Mapped[dict[str, Any]] = mapped_column("metadata", JSON, default_factory=dict)
    created_at: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=_tz.now,
        sort_order=999,
        comment="Event timestamp",
    )
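seq carries the per-thread ordering and is protected by uq_run_events_thread_seq, so event writers must hand out an increasing sequence per thread. One way such an allocator could look (a sketch only; the actual allocation code is not part of this hunk):

# Sequence-allocation sketch; not part of this diff, real writers may allocate seq differently.
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

async def next_seq(session: AsyncSession, thread_id: str) -> int:
    """Naive max+1 allocator; concurrent writers fall back on the unique constraint."""
    current = await session.scalar(
        select(func.max(RunEvent.seq)).where(RunEvent.thread_id == thread_id)
    )
    return (current or 0) + 1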
@ -0,0 +1,25 @@
from __future__ import annotations

from typing import Any

from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import Base, id_key


class ThreadMeta(Base):
    """Thread metadata table."""

    __tablename__ = "thread_meta"

    id: Mapped[id_key] = mapped_column(init=False)

    thread_id: Mapped[str] = mapped_column(String(64), unique=True, index=True)

    assistant_id: Mapped[str | None] = mapped_column(String(64), default=None)
    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    display_name: Mapped[str | None] = mapped_column(String(255), default=None)
    status: Mapped[str] = mapped_column(String(32), default="idle", index=True)

    meta: Mapped[dict[str, Any]] = mapped_column("metadata", JSON, default_factory=dict)
3
backend/packages/storage/store/utils/__init__.py
Normal file
@ -0,0 +1,3 @@
from .timezone import get_timezone

__all__ = ["get_timezone"]
52
backend/packages/storage/store/utils/timezone.py
Normal file
@ -0,0 +1,52 @@
import zoneinfo
from datetime import datetime
from datetime import timezone as datetime_timezone

from store.config.app_config import get_app_config

# IANA identifiers that map to UTC; see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})


class TimeZone:
    def __init__(self) -> None:
        app_config = get_app_config()
        if app_config.timezone in _UTC_IDENTIFIERS:
            self.tz_info = datetime_timezone.utc
        else:
            self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)

    def now(self) -> datetime:
        """Return the current time in the configured timezone."""
        return datetime.now(self.tz_info)

    def from_datetime(self, t: datetime) -> datetime:
        """Convert a datetime to the configured timezone."""
        return t.astimezone(self.tz_info)

    def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
        """Parse a time string and attach the configured timezone."""
        return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)

    @staticmethod
    def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
        """Format a datetime to string."""
        return t.strftime(format_str)

    @staticmethod
    def to_utc(t: datetime | int) -> datetime:
        """Convert a datetime or Unix timestamp to UTC."""
        if isinstance(t, datetime):
            return t.astimezone(datetime_timezone.utc)
        return datetime.fromtimestamp(t, tz=datetime_timezone.utc)


_timezone = None


def get_timezone() -> TimeZone:
    """Return the global TimeZone singleton (lazy-initialized)."""
    global _timezone
    if _timezone is None:
        _timezone = TimeZone()
    return _timezone
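A usage sketch for the singleton; it assumes get_app_config() exposes a timezone string (e.g. "Asia/Shanghai"), as the import above implies:

# Usage sketch; the configured timezone value is an assumption, not part of this diff.
tz = get_timezone()

stamp = tz.now()                              # timezone-aware current time
print(TimeZone.to_str(stamp))                 # "2026-01-01 09:30:00"-style string
parsed = tz.from_str("2026-01-01 09:30:00")   # parse and attach the configured tz
utc = TimeZone.to_utc(parsed)                 # normalize before cross-zone comparison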
@ -34,7 +34,8 @@ markers = [
]

[tool.uv.workspace]
members = ["packages/harness"]
members = ["packages/harness", "packages/storage"]

[tool.uv.sources]
deerflow-harness = { workspace = true }
deerflow-storage = { workspace = true }
84
backend/uv.lock
generated
@ -14,6 +14,7 @@ resolution-markers = [
members = [
    "deer-flow",
    "deerflow-harness",
    "deerflow-storage",
]

[[package]]
@ -136,6 +137,18 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/b4/63/278a98c715ae467624eafe375542d8ba9b4383a016df8fdefe0ae28382a7/aiohttp-3.13.3-cp314-cp314t-win_amd64.whl", hash = "sha256:44531a36aa2264a1860089ffd4dce7baf875ee5a6079d5fb42e261c704ef7344", size = 499694, upload-time = "2026-01-03T17:32:24.546Z" },
]

[[package]]
name = "aiomysql"
version = "0.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "pymysql" },
]
sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" },
]

[[package]]
name = "aiosignal"
version = "1.4.0"
@ -941,6 +954,54 @@ requires-dist = [
]
provides-extras = ["ollama", "postgres", "pymupdf"]

[[package]]
name = "deerflow-storage"
version = "0.1.0"
source = { editable = "packages/storage" }
dependencies = [
    { name = "alembic" },
    { name = "dotenv" },
    { name = "langgraph" },
    { name = "pydantic" },
    { name = "pyyaml" },
    { name = "sqlalchemy", extra = ["asyncio"] },
]

[package.optional-dependencies]
mysql = [
    { name = "aiomysql" },
    { name = "langgraph-checkpoint-mysql" },
]
postgres = [
    { name = "asyncpg" },
    { name = "langgraph-checkpoint-postgres" },
    { name = "psycopg", extra = ["binary"] },
    { name = "psycopg-pool" },
]
sqlite = [
    { name = "aiosqlite" },
    { name = "langgraph-checkpoint-sqlite" },
]

[package.metadata]
requires-dist = [
    { name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" },
    { name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" },
    { name = "alembic", specifier = ">=1.13" },
    { name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
    { name = "dotenv", specifier = ">=0.9.9" },
    { name = "langgraph", specifier = ">=1.0.6,<1.0.10" },
    { name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" },
    { name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" },
    { name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" },
    { name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" },
    { name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" },
    { name = "pydantic", specifier = ">=2.12.5" },
    { name = "pyyaml", specifier = ">=6.0.3" },
    { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" },
]
provides-extras = ["postgres", "mysql", "sqlite"]

[[package]]
name = "defusedxml"
version = "0.7.1"
@ -1956,6 +2017,20 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/4a/de/ddd53b7032e623f3c7bcdab2b44e8bf635e468f62e10e5ff1946f62c9356/langgraph_checkpoint-4.0.0-py3-none-any.whl", hash = "sha256:3fa9b2635a7c5ac28b338f631abf6a030c3b508b7b9ce17c22611513b589c784", size = 46329, upload-time = "2026-01-12T20:30:25.2Z" },
]

[[package]]
name = "langgraph-checkpoint-mysql"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "langgraph-checkpoint" },
    { name = "orjson" },
    { name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" },
]

[[package]]
name = "langgraph-checkpoint-postgres"
version = "3.0.5"
@ -3433,6 +3508,15 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/01/fc/a4977b84f9a7e70aac4c9beed55d4693b985cef89fab7d49c896335bf158/pymupdf4llm-1.27.2.2-py3-none-any.whl", hash = "sha256:ec3bbceed21c6f86289155f29c557aa54ae1c8282c4a45d6de984f16fb4c90cb", size = 84294, upload-time = "2026-03-20T09:45:55.365Z" },
]

[[package]]
name = "pymysql"
version = "1.1.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f5/ae/1fe3fcd9f959efa0ebe200b8de88b5a5ce3e767e38c7ac32fb179f16a388/pymysql-1.1.2.tar.gz", hash = "sha256:4961d3e165614ae65014e361811a724e2044ad3ea3739de9903ae7c21f539f03", size = 48258, upload-time = "2025-08-24T12:55:55.146Z" }
wheels = [
    { url = "https://files.pythonhosted.org/packages/7c/4c/ad33b92b9864cbde84f259d5df035a6447f91891f5be77788e2a3892bce3/pymysql-1.1.2-py3-none-any.whl", hash = "sha256:e6b1d89711dd51f8f74b1631fe08f039e7d76cf67a42a323d3178f0f25762ed9", size = 45300, upload-time = "2025-08-24T12:55:53.394Z" },
]

[[package]]
name = "pypdfium2"
version = "5.3.0"