mirror of https://github.com/bytedance/deer-flow.git (synced 2026-04-27 20:28:16 +00:00)
fix(persistence): address 22 review comments from CodeQL, Copilot, and Code Quality
Bug fixes:
- Sanitize log params to prevent log injection (CodeQL)
- Reset threads_meta.status to idle/error when run completes
- Attach messages only to latest checkpoint in /history response
- Write threads_meta on POST /threads so new threads appear in search

Lint fixes:
- Remove unused imports (journal.py, migrations/env.py, test_converters.py)
- Convert lambda to named function (engine.py, Ruff E731)
- Remove unused logger definitions in repos (Ruff F841)
- Add logging to JSONL decode errors and empty except blocks
- Separate assert side-effects in tests (CodeQL)
- Remove unused local variables in tests (Ruff F841)
- Fix max_trace_content truncation to use byte length, not char length

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent 32f69674a5
commit b94383c93a
@@ -35,6 +35,11 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/api/threads", tags=["threads"])
 
+
+def _sanitize_log_param(value: str) -> str:
+    """Strip control characters to prevent log injection."""
+    return value.replace("\n", "").replace("\r", "").replace("\x00", "")
+
 
 # ---------------------------------------------------------------------------
 # Response / request models
 # ---------------------------------------------------------------------------
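Worth spelling out what the helper buys: without it, a caller-controlled thread_id containing a newline forges a second log record (CWE-117 log injection). A minimal, self-contained demonstration; the logging setup is illustrative, and only _sanitize_log_param itself comes from the diff above:

```python
import logging

logging.basicConfig(format="%(levelname)s %(message)s")
logger = logging.getLogger("demo")


def _sanitize_log_param(value: str) -> str:
    """Strip control characters to prevent log injection."""
    return value.replace("\n", "").replace("\r", "").replace("\x00", "")


# Attacker-controlled identifier embedding a fake log record
thread_id = "t-123\nERROR admin password rejected"

logger.warning("Deleted thread %s", thread_id)
# WARNING Deleted thread t-123
# ERROR admin password rejected          <- forged, looks like a real record

logger.warning("Deleted thread %s", _sanitize_log_param(thread_id))
# WARNING Deleted thread t-123ERROR admin password rejected  <- one attributable line
```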
@@ -136,13 +141,13 @@ def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDel
         raise HTTPException(status_code=422, detail=str(exc)) from exc
     except FileNotFoundError:
         # Not critical — thread data may not exist on disk
-        logger.debug("No local thread data to delete for %s", thread_id)
+        logger.debug("No local thread data to delete for %s", _sanitize_log_param(thread_id))
         return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
     except Exception as exc:
-        logger.exception("Failed to delete thread data for %s", thread_id)
+        logger.exception("Failed to delete thread data for %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
 
-    logger.info("Deleted local thread data for %s", thread_id)
+    logger.info("Deleted local thread data for %s", _sanitize_log_param(thread_id))
     return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
@@ -231,7 +236,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
     try:
         await store.adelete(THREADS_NS, thread_id)
     except Exception:
-        logger.debug("Could not delete store record for thread %s (not critical)", thread_id)
+        logger.debug("Could not delete store record for thread %s (not critical)", _sanitize_log_param(thread_id))
 
     # Remove checkpoints (best-effort)
     checkpointer = getattr(request.app.state, "checkpointer", None)
@@ -240,7 +245,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
         if hasattr(checkpointer, "adelete_thread"):
             await checkpointer.adelete_thread(thread_id)
     except Exception:
-        logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id)
+        logger.debug("Could not delete checkpoints for thread %s (not critical)", _sanitize_log_param(thread_id))
 
     return response
@@ -284,7 +289,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
             },
         )
     except Exception:
-        logger.exception("Failed to write thread %s to store", thread_id)
+        logger.exception("Failed to write thread %s to store", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to create thread")
 
     # Write an empty checkpoint so state endpoints work immediately
@@ -302,10 +307,24 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
         }
         await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
     except Exception:
-        logger.exception("Failed to create checkpoint for thread %s", thread_id)
+        logger.exception("Failed to create checkpoint for thread %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to create thread")
 
-    logger.info("Thread created: %s", thread_id)
+    # Write thread_meta so the thread appears in /threads/search immediately
+    from app.gateway.deps import get_thread_meta_repo
+
+    thread_meta_repo = get_thread_meta_repo(request)
+    if thread_meta_repo is not None:
+        try:
+            await thread_meta_repo.create(
+                thread_id,
+                assistant_id=getattr(body, "assistant_id", None),
+                metadata=body.metadata,
+            )
+        except Exception:
+            logger.debug("Failed to upsert thread_meta on create for %s (non-fatal)", _sanitize_log_param(thread_id))
+
+    logger.info("Thread created: %s", _sanitize_log_param(thread_id))
     return ThreadResponse(
         thread_id=thread_id,
         status="idle",
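The effect of the new thread_meta write, seen from a client: previously a thread created over HTTP was invisible to search until its first run wrote threads_meta. A hedged sketch of that check; the /api/threads prefix and the /threads/search behavior come from this commit, but the exact search request and response shapes are assumptions:

```python
import httpx

BASE_URL = "http://localhost:8000"  # assumed local gateway address

with httpx.Client(base_url=BASE_URL) as client:
    created = client.post("/api/threads", json={"metadata": {"title": "demo"}}).json()
    thread_id = created["thread_id"]

    # Before this commit, threads_meta was only written when a run started,
    # so a freshly created thread was missing from search results.
    # (Assumes the search route returns a list of thread objects.)
    results = client.post("/api/threads/search", json={"limit": 10}).json()
    assert any(t.get("thread_id") == thread_id for t in results)
```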
@@ -372,7 +391,7 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
     try:
         await _store_put(store, updated)
     except Exception:
-        logger.exception("Failed to patch thread %s", thread_id)
+        logger.exception("Failed to patch thread %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to update thread")
 
     return ThreadResponse(
@@ -404,7 +423,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
     try:
         checkpoint_tuple = await checkpointer.aget_tuple(config)
     except Exception:
-        logger.exception("Failed to get checkpoint for thread %s", thread_id)
+        logger.exception("Failed to get checkpoint for thread %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to get thread")
 
     if record is None and checkpoint_tuple is None:
@@ -452,7 +471,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
     try:
         checkpoint_tuple = await checkpointer.aget_tuple(config)
     except Exception:
-        logger.exception("Failed to get state for thread %s", thread_id)
+        logger.exception("Failed to get state for thread %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to get thread state")
 
     if checkpoint_tuple is None:
@@ -514,7 +533,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
     try:
         checkpoint_tuple = await checkpointer.aget_tuple(read_config)
     except Exception:
-        logger.exception("Failed to get state for thread %s", thread_id)
+        logger.exception("Failed to get state for thread %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to get thread state")
 
     if checkpoint_tuple is None:
@@ -548,7 +567,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
     try:
         new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
     except Exception:
-        logger.exception("Failed to update state for thread %s", thread_id)
+        logger.exception("Failed to update state for thread %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to update thread state")
 
     new_checkpoint_id: str | None = None
@@ -560,7 +579,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
         try:
             await _store_upsert(store, thread_id, values={"title": body.values["title"]})
         except Exception:
-            logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id)
+            logger.debug("Failed to sync title to store for thread %s (non-fatal)", _sanitize_log_param(thread_id))
 
     return ThreadStateResponse(
         values=serialize_channel_values(channel_values),
@@ -594,16 +613,12 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
     try:
         all_messages = await event_store.list_messages(thread_id, limit=10_000)
     except Exception:
-        logger.warning("Failed to load messages from event store for thread %s", thread_id, exc_info=True)
+        logger.warning("Failed to load messages from event store for thread %s", _sanitize_log_param(thread_id), exc_info=True)
         all_messages = []
 
-    # Group messages by run_id for per-checkpoint assembly
-    messages_by_run: dict[str, list[dict]] = {}
-    for msg in all_messages:
-        run_id = msg.get("run_id", "")
-        messages_by_run.setdefault(run_id, []).append(msg.get("content", {}))
-
     entries: list[HistoryEntry] = []
+    is_latest_checkpoint = True
     try:
         async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
             ckpt_config = getattr(checkpoint_tuple, "config", {})
@@ -625,9 +640,10 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
             if thread_data := channel_values.get("thread_data"):
                 values["thread_data"] = thread_data
 
-            # Attach all messages from event store (not just this checkpoint's run)
-            if all_messages:
+            # Attach all messages only to the latest (first) checkpoint entry
+            if is_latest_checkpoint and all_messages:
                 values["messages"] = [m.get("content", {}) for m in all_messages]
+                is_latest_checkpoint = False
 
             # Derive next tasks
             tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
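The one-shot flag works because the checkpointer lists checkpoints newest first, which the "latest (first)" comment in the diff relies on: only the first loop iteration attaches the message list, so older history entries stop carrying (and misrepresenting) the full conversation. The control flow in isolation, with stand-in data:

```python
# Stand-in data: alist() yields checkpoints newest first.
checkpoints = ["ckpt-3-latest", "ckpt-2", "ckpt-1"]
all_messages = [{"content": {"role": "user", "content": "hi"}}]

entries = []
is_latest_checkpoint = True
for ckpt in checkpoints:
    values = {"checkpoint_id": ckpt}
    if is_latest_checkpoint and all_messages:
        # Only the newest checkpoint carries the full message history.
        values["messages"] = [m.get("content", {}) for m in all_messages]
        is_latest_checkpoint = False
    entries.append(values)

assert "messages" in entries[0]
assert all("messages" not in e for e in entries[1:])
```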
@@ -650,7 +666,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
             )
         )
     except Exception:
-        logger.exception("Failed to get history for thread %s", thread_id)
+        logger.exception("Failed to get history for thread %s", _sanitize_log_param(thread_id))
         raise HTTPException(status_code=500, detail="Failed to get thread history")
 
     return entries
@@ -18,6 +18,7 @@ from fastapi import HTTPException, Request
 from langchain_core.messages import HumanMessage
 
 from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo
+from app.gateway.routers.threads import _sanitize_log_param
 from deerflow.runtime import (
     END_SENTINEL,
     HEARTBEAT_SENTINEL,
@@ -184,7 +185,7 @@ async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None)
     try:
         await _store_upsert(store, thread_id, metadata=metadata)
     except Exception:
-        logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id)
+        logger.warning("Failed to upsert thread %s in store (non-fatal)", _sanitize_log_param(thread_id))
 
 
 async def _sync_thread_title_after_run(
@@ -312,7 +313,7 @@ async def start_run(
         else:
             await thread_meta_repo.update_status(thread_id, "running")
     except Exception:
-        logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id)
+        logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
 
     agent_factory = resolve_agent_factory(body.assistant_id)
     graph_input = normalize_input(body.input)
@@ -10,13 +10,15 @@ None and fall back to in-memory implementations.
 
 from __future__ import annotations
 
-import logging
+import json
+import logging
 
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
 
-_json_serializer = lambda obj: json.dumps(obj, ensure_ascii=False)
+
+def _json_serializer(obj: object) -> str:
+    """JSON serializer with ensure_ascii=False for Chinese character support."""
+    return json.dumps(obj, ensure_ascii=False)
+
 
 logger = logging.getLogger(__name__)
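Independent of the E731 lambda-to-def fix, the serializer's ensure_ascii=False is what keeps CJK text readable and compact in stored JSON; presumably it is wired into the engine's JSON serialization hook, though that wiring is not shown in this hunk. The difference at a glance:

```python
import json

payload = {"title": "深度研究报告"}

print(json.dumps(payload))
# {"title": "\u6df1\u5ea6\u7814\u7a76\u62a5\u544a"}   (escaped, several times larger)

print(json.dumps(payload, ensure_ascii=False))
# {"title": "深度研究报告"}                             (readable, compact)
```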
@@ -106,7 +108,9 @@ async def init_engine(
     try:
         import deerflow.persistence.models  # noqa: F401
     except ImportError:
-        pass
+        # Models package not yet available — tables won't be auto-created.
+        # This is expected during initial scaffolding or minimal installs.
+        logger.debug("deerflow.persistence.models not found; skipping auto-create tables")
 
     try:
         async with _engine.begin() as conn:
@@ -8,6 +8,7 @@ have their own schema lifecycle and must not be touched by Alembic.
 from __future__ import annotations
 
 import asyncio
+import logging
 from logging.config import fileConfig
 
 from alembic import context
@@ -17,9 +18,13 @@ from deerflow.persistence.base import Base
 
 # Import all models so metadata is populated.
 try:
-    import deerflow.persistence.models  # noqa: F401
+    import deerflow.persistence.models  # noqa: F401 — register ORM models with Base.metadata
 except ImportError:
-    pass
+    # Models not available — migration will work with existing metadata only.
+    logging.getLogger(__name__).warning(
+        "Could not import deerflow.persistence.models; "
+        "Alembic may not detect all tables"
+    )
 
 config = context.config
 if config.config_file_name is not None:
@@ -5,7 +5,6 @@ Each method acquires its own short-lived session.
 """
 
 from __future__ import annotations
 
-import logging
 import uuid
 from datetime import UTC, datetime
@@ -14,8 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 
 from deerflow.persistence.models.feedback import FeedbackRow
 
-logger = logging.getLogger(__name__)
-
 
 class FeedbackRepository:
     def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
@@ -8,7 +8,6 @@ minutes -- we don't hold connections across long execution.
 from __future__ import annotations
 
 import json
-import logging
 from datetime import UTC, datetime
 from typing import Any
@@ -18,8 +17,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 from deerflow.persistence.models.run import RunRow
 from deerflow.runtime.runs.store.base import RunStore
 
-logger = logging.getLogger(__name__)
-
 
 class RunRepository(RunStore):
     def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
@@ -2,7 +2,6 @@
 from __future__ import annotations
 
-import logging
 from datetime import UTC, datetime
 from typing import Any
@@ -11,8 +10,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 
 from deerflow.persistence.models.thread_meta import ThreadMetaRow
 
-logger = logging.getLogger(__name__)
-
 
 class ThreadMetaRepository:
     def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
@@ -7,6 +7,7 @@ at ``max_trace_content`` bytes to avoid bloating the database.
 from __future__ import annotations
 
 import json
+import logging
 from datetime import UTC, datetime
 
 from sqlalchemy import delete, func, select
@@ -15,6 +16,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 from deerflow.persistence.models.run_event import RunEventRow
 from deerflow.runtime.events.store.base import RunEventStore
 
+logger = logging.getLogger(__name__)
+
 
 class DbRunEventStore(RunEventStore):
     def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
@@ -35,15 +38,19 @@ class DbRunEventStore(RunEventStore):
             try:
                 d["content"] = json.loads(raw)
             except (json.JSONDecodeError, ValueError):
-                pass
+                # Content looked like JSON (content_is_dict flag) but failed to parse;
+                # keep the raw string as-is.
+                logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
         return d
 
     def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
         if category == "trace":
             text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
-            if len(text) > self._max_trace_content:
-                content = text[: self._max_trace_content]
-                metadata = {**(metadata or {}), "content_truncated": True}
+            encoded = text.encode("utf-8")
+            if len(encoded) > self._max_trace_content:
+                # Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
+                content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
+                metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
         return content, metadata or {}
 
     async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
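The bug the truncation fix addresses: len(text) counts characters, but storage budgets are in bytes, and a CJK-heavy trace is up to three times larger encoded than its character count suggests. Slicing the UTF-8 bytes can also split a character, which is why the decode uses errors="ignore". A scaled-down illustration:

```python
text = "研" * 10                 # 10 characters, 30 UTF-8 bytes
limit = 16                       # stand-in for max_trace_content

# Old check: character length (10) is under the limit, so nothing was
# truncated even though the encoded payload (30 bytes) blows the budget.
assert len(text) < limit

encoded = text.encode("utf-8")
assert len(encoded) > limit      # new byte-length check catches it

# Byte-slice then decode; byte 16 splits a character, errors="ignore" drops it.
truncated = encoded[:limit].decode("utf-8", errors="ignore")
print(truncated)                            # 研研研研研
print(len(truncated.encode("utf-8")))       # 15, within the byte budget
```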
@@ -51,6 +51,7 @@ class JsonlRunEventStore(RunEventStore):
                     record = json.loads(line)
                     max_seq = max(max_seq, record.get("seq", 0))
                 except json.JSONDecodeError:
+                    logger.debug("Skipping malformed JSONL line in %s", f)
                     continue
         self._seq_counters[thread_id] = max_seq
@@ -73,6 +74,7 @@ class JsonlRunEventStore(RunEventStore):
                 try:
                     events.append(json.loads(line))
                 except json.JSONDecodeError:
+                    logger.debug("Skipping malformed JSONL line in %s", f)
                     continue
         events.sort(key=lambda e: e.get("seq", 0))
         return events
@@ -89,6 +91,7 @@ class JsonlRunEventStore(RunEventStore):
                 try:
                     events.append(json.loads(line))
                 except json.JSONDecodeError:
+                    logger.debug("Skipping malformed JSONL line in %s", path)
                     continue
         events.sort(key=lambda e: e.get("seq", 0))
         return events
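All three JSONL loops adopt the same posture: a torn line (for example, from a crash mid-append) is logged at debug level and skipped instead of aborting the whole read. The pattern extracted into a standalone helper; the function name and signature here are illustrative, not the store's actual API:

```python
import json
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


def load_jsonl_events(path: Path) -> list[dict]:
    """Read a JSONL journal, tolerating individually corrupt lines."""
    events: list[dict] = []
    for line in path.read_text(encoding="utf-8").splitlines():
        if not line.strip():
            continue
        try:
            events.append(json.loads(line))
        except json.JSONDecodeError:
            # One bad line must not discard the rest of the journal,
            # but it should leave a trace for debugging.
            logger.debug("Skipping malformed JSONL line in %s", path)
    events.sort(key=lambda e: e.get("seq", 0))
    return events
```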
@@ -135,7 +135,7 @@ class RunJournal(BaseCallbackHandler):
         self._llm_start_times[str(run_id)] = time.monotonic()
 
     def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
-        from deerflow.runtime.converters import langchain_to_openai_completion, langchain_to_openai_message
+        from deerflow.runtime.converters import langchain_to_openai_completion
 
         try:
             message = response.generations[0][0].message
@@ -17,7 +17,10 @@ from __future__ import annotations
 
 import asyncio
 import logging
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
+
+if TYPE_CHECKING:
+    from langchain_core.messages import HumanMessage
 
 from deerflow.runtime.serialization import serialize
 from deerflow.runtime.stream_bridge import StreamBridge
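The TYPE_CHECKING import costs nothing at runtime: the typing constant is False outside static analysis, and with `from __future__ import annotations` (already at the top of this module, per the hunk header) annotations stay unevaluated strings, so the bare HumanMessage name never has to resolve when the module loads. The mechanics in miniature:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright only; never imported at runtime.
    from langchain_core.messages import HumanMessage


def first_message(messages: list) -> HumanMessage | None:
    # With PEP 563 semantics the return annotation is just the string
    # "HumanMessage | None", so this runs even without langchain installed.
    return messages[0] if messages else None


print(first_message([]))  # None
```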
@@ -273,6 +276,14 @@ async def run_agent(
     except Exception:
         logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
 
+    # Update threads_meta status based on run outcome
+    if thread_meta_repo is not None:
+        try:
+            final_status = "idle" if record.status == RunStatus.success else record.status.value
+            await thread_meta_repo.update_status(thread_id, final_status)
+        except Exception:
+            logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
+
     await bridge.publish_end(run_id)
     asyncio.create_task(bridge.cleanup(run_id, delay=60))
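The new block maps run outcomes onto thread status: a successful run collapses back to idle, while any other outcome (error, interrupted, and so on) is written through verbatim. A sketch of that mapping with a hypothetical RunStatus stand-in, since the real enum is not part of this diff:

```python
from enum import Enum


class RunStatus(str, Enum):  # hypothetical stand-in for deerflow's enum
    success = "success"
    error = "error"
    interrupted = "interrupted"


def final_thread_status(run_status: RunStatus) -> str:
    # Mirrors: "idle" if record.status == RunStatus.success else record.status.value
    return "idle" if run_status == RunStatus.success else run_status.value


assert final_thread_status(RunStatus.success) == "idle"
assert final_thread_status(RunStatus.error) == "error"
assert final_thread_status(RunStatus.interrupted) == "interrupted"
```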
@@ -294,7 +305,7 @@ def _lg_mode_to_sse_event(mode: str) -> str:
     return mode
 
 
-def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
+def _extract_human_message(graph_input: dict) -> HumanMessage | None:
     """Extract or construct a HumanMessage from graph_input for event recording.
 
     Returns a LangChain HumanMessage so callers can use .model_dump() to get
@@ -5,10 +5,7 @@ from __future__ import annotations
 
-import json
-from unittest.mock import MagicMock
-
 import pytest
 
 from deerflow.runtime.converters import (
     _infer_finish_reason,
     langchain_messages_to_openai,
     langchain_to_openai_completion,
     langchain_to_openai_message,
@@ -117,14 +117,16 @@ class TestFeedbackRepository:
     async def test_delete(self, tmp_path):
         repo = await _make_feedback_repo(tmp_path)
         created = await repo.create(run_id="r1", thread_id="t1", rating=1)
-        assert await repo.delete(created["feedback_id"]) is True
+        deleted = await repo.delete(created["feedback_id"])
+        assert deleted is True
         assert await repo.get(created["feedback_id"]) is None
         await _cleanup()
 
     @pytest.mark.anyio
     async def test_delete_nonexistent(self, tmp_path):
         repo = await _make_feedback_repo(tmp_path)
-        assert await repo.delete("nonexistent") is False
+        deleted = await repo.delete("nonexistent")
+        assert deleted is False
         await _cleanup()
 
     @pytest.mark.anyio
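Why CodeQL objects to the original form: `python -O` strips assert statements wholesale, so a delete buried inside an assert would silently stop executing under optimized bytecode, and the following state check would then test stale data. Separating the side effect from the assertion keeps the mutation unconditional. A runnable miniature with a fake repository (FakeRepo is illustrative, not the project's class):

```python
import asyncio


class FakeRepo:
    """Hypothetical stand-in for FeedbackRepository."""

    def __init__(self) -> None:
        self.rows = {"f1": {"rating": 1}}

    async def delete(self, feedback_id: str) -> bool:
        return self.rows.pop(feedback_id, None) is not None


async def main() -> None:
    repo = FakeRepo()

    # Fragile: under `python -O` this entire statement, delete included, vanishes.
    #     assert await repo.delete("f1") is True

    # Robust: the delete always runs; only the check is skippable.
    deleted = await repo.delete("f1")
    assert deleted is True
    assert "f1" not in repo.rows


asyncio.run(main())
```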
@@ -225,6 +225,8 @@ class TestEngineLifecycle:
             pytest.skip("asyncpg is installed -- cannot test missing-dep path")
         except ImportError:
-            pass
+            # asyncpg is not installed — this is the expected state for this test.
+            # We proceed to verify that init_engine raises an actionable ImportError.
+            pass  # noqa: S110 — intentionally ignored
         with pytest.raises(ImportError, match="uv sync --extra postgres"):
             await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
@@ -456,7 +456,7 @@ class TestDictContentFlag:
         sf = get_session_factory()
         store = DbRunEventStore(sf)
 
-        record = await store.put(
+        await store.put(
             thread_id="t1",
             run_id="r1",
             event_type="tool_end",
@@ -480,7 +480,7 @@ class TestDictContentFlag:
         sf = get_session_factory()
         store = DbRunEventStore(sf)
 
-        record = await store.put(
+        await store.put(
            thread_id="t1",
            run_id="r1",
            event_type="tool_end",