mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat(threads): switch search endpoint to threads_meta table and sync title
- POST /api/threads/search now queries threads_meta table directly, removing the two-phase Store + Checkpointer scan approach - Add ThreadMetaRepository.search() with metadata/status filters - Add ThreadMetaRepository.update_display_name() for title sync - Worker syncs checkpoint title to threads_meta.display_name on run completion - Map display_name to values.title in search response for API compatibility Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
52e7acafee
commit
35001c7c73
@ -317,107 +317,31 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
|
||||
@router.post("/search", response_model=list[ThreadResponse])
|
||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
||||
"""Search and list threads.
|
||||
"""Search and list threads from the threads_meta table."""
|
||||
from app.gateway.deps import get_thread_meta_repo
|
||||
|
||||
Two-phase approach:
|
||||
repo = get_thread_meta_repo(request)
|
||||
if repo is None:
|
||||
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
||||
|
||||
**Phase 1 — Store (fast path, O(threads))**: returns threads that were
|
||||
created or run through this Gateway. Store records are tiny metadata
|
||||
dicts so fetching all of them at once is cheap.
|
||||
|
||||
**Phase 2 — Checkpointer supplement (lazy migration)**: threads that
|
||||
were created directly by LangGraph Server (and therefore absent from the
|
||||
Store) are discovered here by iterating the shared checkpointer. Any
|
||||
newly found thread is immediately written to the Store so that the next
|
||||
search skips Phase 2 for that thread — the Store converges to a full
|
||||
index over time without a one-shot migration job.
|
||||
"""
|
||||
store = get_store(request)
|
||||
checkpointer = get_checkpointer(request)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 1: Store
|
||||
# -----------------------------------------------------------------------
|
||||
merged: dict[str, ThreadResponse] = {}
|
||||
|
||||
if store is not None:
|
||||
try:
|
||||
items = await store.asearch(THREADS_NS, limit=10_000)
|
||||
except Exception:
|
||||
logger.warning("Store search failed — falling back to checkpointer only", exc_info=True)
|
||||
items = []
|
||||
|
||||
for item in items:
|
||||
val = item.value
|
||||
merged[val["thread_id"]] = ThreadResponse(
|
||||
thread_id=val["thread_id"],
|
||||
status=val.get("status", "idle"),
|
||||
created_at=str(val.get("created_at", "")),
|
||||
updated_at=str(val.get("updated_at", "")),
|
||||
metadata=val.get("metadata", {}),
|
||||
values=val.get("values", {}),
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 2: Checkpointer supplement
|
||||
# Discovers threads not yet in the Store (e.g. created by LangGraph
|
||||
# Server) and lazily migrates them so future searches skip this phase.
|
||||
# -----------------------------------------------------------------------
|
||||
try:
|
||||
async for checkpoint_tuple in checkpointer.alist(None):
|
||||
cfg = getattr(checkpoint_tuple, "config", {})
|
||||
thread_id = cfg.get("configurable", {}).get("thread_id")
|
||||
if not thread_id or thread_id in merged:
|
||||
continue
|
||||
|
||||
# Skip sub-graph checkpoints (checkpoint_ns is non-empty for those)
|
||||
if cfg.get("configurable", {}).get("checkpoint_ns", ""):
|
||||
continue
|
||||
|
||||
ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {}
|
||||
# Strip LangGraph internal keys from the user-visible metadata dict
|
||||
user_meta = {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}
|
||||
|
||||
# Extract state values (title) from the checkpoint's channel_values
|
||||
checkpoint_data = getattr(checkpoint_tuple, "checkpoint", {}) or {}
|
||||
channel_values = checkpoint_data.get("channel_values", {})
|
||||
ckpt_values = {}
|
||||
if title := channel_values.get("title"):
|
||||
ckpt_values["title"] = title
|
||||
|
||||
thread_resp = ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
status=_derive_thread_status(checkpoint_tuple),
|
||||
created_at=str(ckpt_meta.get("created_at", "")),
|
||||
updated_at=str(ckpt_meta.get("updated_at", ckpt_meta.get("created_at", ""))),
|
||||
metadata=user_meta,
|
||||
values=ckpt_values,
|
||||
)
|
||||
merged[thread_id] = thread_resp
|
||||
|
||||
# Lazy migration — write to Store so the next search finds it there
|
||||
if store is not None:
|
||||
try:
|
||||
await _store_upsert(store, thread_id, metadata=user_meta, values=ckpt_values or None)
|
||||
except Exception:
|
||||
logger.debug("Failed to migrate thread %s to store (non-fatal)", thread_id)
|
||||
except Exception:
|
||||
logger.exception("Checkpointer scan failed during thread search")
|
||||
# Don't raise — return whatever was collected from Store + partial scan
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Phase 3: Filter → sort → paginate
|
||||
# -----------------------------------------------------------------------
|
||||
results = list(merged.values())
|
||||
|
||||
if body.metadata:
|
||||
results = [r for r in results if all(r.metadata.get(k) == v for k, v in body.metadata.items())]
|
||||
|
||||
if body.status:
|
||||
results = [r for r in results if r.status == body.status]
|
||||
|
||||
results.sort(key=lambda r: r.updated_at, reverse=True)
|
||||
return results[body.offset : body.offset + body.limit]
|
||||
rows = await repo.search(
|
||||
metadata=body.metadata or None,
|
||||
status=body.status,
|
||||
limit=body.limit,
|
||||
offset=body.offset,
|
||||
)
|
||||
return [
|
||||
ThreadResponse(
|
||||
thread_id=r["thread_id"],
|
||||
status=r.get("status", "idle"),
|
||||
created_at=r.get("created_at", ""),
|
||||
updated_at=r.get("updated_at", ""),
|
||||
metadata=r.get("metadata", {}),
|
||||
values={"title": r["display_name"]} if r.get("display_name") else {},
|
||||
interrupts={},
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
|
||||
@router.patch("/{thread_id}", response_model=ThreadResponse)
|
||||
|
||||
@ -323,6 +323,7 @@ async def start_run(
|
||||
event_store=event_store,
|
||||
run_events_config=run_events_config,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
thread_meta_repo=thread_meta_repo,
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
|
||||
@ -78,6 +78,37 @@ class ThreadMetaRepository:
|
||||
return True
|
||||
return row.owner_id == owner_id
|
||||
|
||||
async def search(
|
||||
self,
|
||||
*,
|
||||
metadata: dict | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[dict]:
|
||||
"""Search threads with optional metadata and status filters."""
|
||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||
if status:
|
||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
rows = [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
if metadata:
|
||||
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
|
||||
return rows
|
||||
|
||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||
"""Update the display_name (title) for a thread."""
|
||||
async with self._sf() as session:
|
||||
await session.execute(
|
||||
update(ThreadMetaRow)
|
||||
.where(ThreadMetaRow.thread_id == thread_id)
|
||||
.values(display_name=display_name, updated_at=datetime.now(UTC))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def update_status(self, thread_id: str, status: str) -> None:
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||
|
||||
@ -48,6 +48,7 @@ async def run_agent(
|
||||
event_store: Any | None = None,
|
||||
run_events_config: Any | None = None,
|
||||
follow_up_to_run_id: str | None = None,
|
||||
thread_meta_repo: Any | None = None,
|
||||
) -> None:
|
||||
"""Execute an agent in the background, publishing events to *bridge*."""
|
||||
|
||||
@ -262,6 +263,19 @@ async def run_agent(
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||
|
||||
# Sync title from checkpoint to threads_meta.display_name
|
||||
if thread_meta_repo is not None and checkpointer is not None:
|
||||
try:
|
||||
ckpt_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(ckpt_config)
|
||||
if ckpt_tuple is not None:
|
||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||
title = ckpt.get("channel_values", {}).get("title")
|
||||
if title:
|
||||
await thread_meta_repo.update_display_name(thread_id, title)
|
||||
except Exception:
|
||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||
|
||||
await bridge.publish_end(run_id)
|
||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user