mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-29 13:28:11 +00:00
refactor(persistence): remove UTFJSON, use engine-level json_serializer + datetime.now()
- Replace custom UTFJSON type with standard sqlalchemy.JSON in all ORM models. Add json_serializer=json.dumps(ensure_ascii=False) to all create_async_engine calls so non-ASCII text (Chinese etc.) is stored as-is in both SQLite and Postgres. - Change ORM datetime defaults from datetime.now(UTC) to datetime.now(), remove UTC imports. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
14c5f4b798
commit
3b4622a26f
@ -12,14 +12,45 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
_json_serializer = lambda obj: json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_engine: AsyncEngine | None = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
async def _auto_create_postgres_db(url: str) -> None:
|
||||
"""Connect to the ``postgres`` maintenance DB and CREATE DATABASE.
|
||||
|
||||
The target database name is extracted from *url*. The connection is
|
||||
made to the default ``postgres`` database on the same server using
|
||||
``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a
|
||||
transaction).
|
||||
"""
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
parsed = make_url(url)
|
||||
db_name = parsed.database
|
||||
if not db_name:
|
||||
raise ValueError("Cannot auto-create database: no database name in URL")
|
||||
|
||||
# Connect to the default 'postgres' database to issue CREATE DATABASE
|
||||
maint_url = parsed.set(database="postgres")
|
||||
maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT")
|
||||
try:
|
||||
async with maint_engine.connect() as conn:
|
||||
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
|
||||
logger.info("Auto-created PostgreSQL database: %s", db_name)
|
||||
finally:
|
||||
await maint_engine.dispose()
|
||||
|
||||
|
||||
async def init_engine(
|
||||
backend: str,
|
||||
*,
|
||||
@ -53,13 +84,14 @@ async def init_engine(
|
||||
import os
|
||||
|
||||
os.makedirs(sqlite_dir or ".", exist_ok=True)
|
||||
_engine = create_async_engine(url, echo=echo)
|
||||
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
|
||||
elif backend == "postgres":
|
||||
_engine = create_async_engine(
|
||||
url,
|
||||
echo=echo,
|
||||
pool_size=pool_size,
|
||||
pool_pre_ping=True,
|
||||
json_serializer=_json_serializer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown persistence backend: {backend!r}")
|
||||
@ -76,8 +108,21 @@ async def init_engine(
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
try:
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
except Exception as exc:
|
||||
if backend == "postgres" and "does not exist" in str(exc):
|
||||
# Database not yet created — attempt to auto-create it, then retry.
|
||||
await _auto_create_postgres_db(url)
|
||||
# Rebuild engine against the now-existing database
|
||||
await _engine.dispose()
|
||||
_engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer)
|
||||
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||
async with _engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
else:
|
||||
raise
|
||||
|
||||
logger.info("Persistence engine initialized: backend=%s", backend)
|
||||
|
||||
|
||||
@ -2,9 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy import DateTime, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@ -27,4 +27,4 @@ class FeedbackRow(Base):
|
||||
comment: Mapped[str | None] = mapped_column(Text)
|
||||
# Optional text feedback from the user
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||
|
||||
@ -2,9 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import JSON, Index, String, Text
|
||||
from sqlalchemy import JSON, DateTime, Index, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@ -43,7 +43,7 @@ class RunRow(Base):
|
||||
# Follow-up association
|
||||
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
|
||||
|
||||
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
||||
|
||||
@ -2,9 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import JSON, Index, String, Text
|
||||
from sqlalchemy import JSON, DateTime, Index, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@ -22,7 +22,7 @@ class RunEventRow(Base):
|
||||
content: Mapped[str] = mapped_column(Text, default="")
|
||||
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
seq: Mapped[int] = mapped_column(nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
||||
|
||||
@ -2,9 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy import JSON, DateTime, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
@ -19,5 +19,5 @@ class ThreadMetaRow(Base):
|
||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user