mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
style: apply ruff format to persistence and runtime files
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
107b3143c3
commit
07954cf9d2
@ -616,7 +616,6 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
||||
logger.warning("Failed to load messages from event store for thread %s", _sanitize_log_param(thread_id), exc_info=True)
|
||||
all_messages = []
|
||||
|
||||
|
||||
entries: list[HistoryEntry] = []
|
||||
is_latest_checkpoint = True
|
||||
try:
|
||||
|
||||
@ -20,6 +20,7 @@ def _json_serializer(obj: object) -> str:
|
||||
"""JSON serializer with ensure_ascii=False for Chinese character support."""
|
||||
return json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_engine: AsyncEngine | None = None
|
||||
|
||||
@ -21,10 +21,7 @@ try:
|
||||
import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
|
||||
except ImportError:
|
||||
# Models not available — migration will work with existing metadata only.
|
||||
logging.getLogger(__name__).warning(
|
||||
"Could not import deerflow.persistence.models; "
|
||||
"Alembic may not detect all tables"
|
||||
)
|
||||
logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables")
|
||||
|
||||
config = context.config
|
||||
if config.config_file_name is not None:
|
||||
|
||||
@ -99,11 +99,7 @@ class ThreadMetaRepository:
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
async with self._sf() as session:
|
||||
await session.execute(
|
||||
update(ThreadMetaRow)
|
||||
.where(ThreadMetaRow.thread_id == thread_id)
|
||||
.values(display_name=display_name, updated_at=datetime.now(UTC))
|
||||
)
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||
await session.commit()
|
||||
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
|
||||
@ -47,14 +47,16 @@ def langchain_to_openai_message(message: Any) -> dict:
|
||||
openai_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
args = tc.get("args", {})
|
||||
openai_tool_calls.append({
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.get("name", ""),
|
||||
"arguments": json.dumps(args) if not isinstance(args, str) else args,
|
||||
},
|
||||
})
|
||||
openai_tool_calls.append(
|
||||
{
|
||||
"id": tc.get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tc.get("name", ""),
|
||||
"arguments": json.dumps(args) if not isinstance(args, str) else args,
|
||||
},
|
||||
}
|
||||
)
|
||||
# If no text content, set content to null per OpenAI spec
|
||||
result["content"] = content if (isinstance(content, list) and content) or (isinstance(content, str) and content) else None
|
||||
result["tool_calls"] = openai_tool_calls
|
||||
|
||||
@ -65,11 +65,7 @@ class DbRunEventStore(RunEventStore):
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
max_seq = await session.scalar(
|
||||
select(func.max(RunEventRow.seq))
|
||||
.where(RunEventRow.thread_id == thread_id)
|
||||
.with_for_update()
|
||||
)
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
@ -93,11 +89,7 @@ class DbRunEventStore(RunEventStore):
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(
|
||||
select(func.max(RunEventRow.seq))
|
||||
.where(RunEventRow.thread_id == thread_id)
|
||||
.with_for_update()
|
||||
)
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
|
||||
@ -357,15 +357,17 @@ class RunJournal(BaseCallbackHandler):
|
||||
# -- Internal methods --
|
||||
|
||||
def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
|
||||
self._buffer.append({
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
})
|
||||
self._buffer.append(
|
||||
{
|
||||
"thread_id": self.thread_id,
|
||||
"run_id": self.run_id,
|
||||
"event_type": event_type,
|
||||
"category": category,
|
||||
"content": content,
|
||||
"metadata": metadata or {},
|
||||
"created_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
if len(self._buffer) >= self._flush_threshold:
|
||||
self._flush_sync()
|
||||
|
||||
@ -395,7 +397,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Failed to flush %d events for run %s — returning to buffer",
|
||||
len(batch), self.run_id, exc_info=True,
|
||||
len(batch),
|
||||
self.run_id,
|
||||
exc_info=True,
|
||||
)
|
||||
# Return failed events to buffer for retry on next flush
|
||||
self._buffer = batch + self._buffer
|
||||
|
||||
@ -41,7 +41,10 @@ class TestFeedbackRepository:
|
||||
async def test_create_negative_with_comment(self, tmp_path):
|
||||
repo = await _make_feedback_repo(tmp_path)
|
||||
record = await repo.create(
|
||||
run_id="r1", thread_id="t1", rating=-1, comment="Response was inaccurate",
|
||||
run_id="r1",
|
||||
thread_id="t1",
|
||||
rating=-1,
|
||||
comment="Response was inaccurate",
|
||||
)
|
||||
assert record["rating"] == -1
|
||||
assert record["comment"] == "Response was inaccurate"
|
||||
|
||||
@ -947,8 +947,10 @@ class TestFullRunSequence:
|
||||
# 1. Human message (written by worker, using model_dump format)
|
||||
human_msg = HumanMessage(content="Search for quantum computing")
|
||||
await store.put(
|
||||
thread_id="t1", run_id="r1",
|
||||
event_type="human_message", category="message",
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=human_msg.model_dump(),
|
||||
)
|
||||
j.set_first_human_message("Search for quantum computing")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user