"""Thread CRUD, state, and history endpoints. Combines the existing thread-local filesystem cleanup with LangGraph Platform-compatible thread management backed by the checkpointer. Channel values returned in state responses are serialized through :func:`deerflow.runtime.serialization.serialize_channel_values` to ensure LangChain message objects are converted to JSON-safe dicts matching the LangGraph Platform wire format expected by the ``useStream`` React hook. """ from __future__ import annotations import logging import time import uuid from typing import Any from fastapi import APIRouter, HTTPException, Request from pydantic import BaseModel, Field from app.gateway.deps import get_checkpointer, get_store from deerflow.config.paths import Paths, get_paths from deerflow.runtime import serialize_channel_values # --------------------------------------------------------------------------- # Store namespace # --------------------------------------------------------------------------- THREADS_NS: tuple[str, ...] = ("threads",) """Namespace used by the Store for thread metadata records.""" logger = logging.getLogger(__name__) router = APIRouter(prefix="/api/threads", tags=["threads"]) def _sanitize_log_param(value: str) -> str: """Strip control characters to prevent log injection.""" return value.replace("\n", "").replace("\r", "").replace("\x00", "") # --------------------------------------------------------------------------- # Response / request models # --------------------------------------------------------------------------- class ThreadDeleteResponse(BaseModel): """Response model for thread cleanup.""" success: bool message: str class ThreadResponse(BaseModel): """Response model for a single thread.""" thread_id: str = Field(description="Unique thread identifier") status: str = Field(default="idle", description="Thread status: idle, busy, interrupted, error") created_at: str = Field(default="", description="ISO timestamp") updated_at: str = Field(default="", description="ISO timestamp") metadata: dict[str, Any] = Field(default_factory=dict, description="Thread metadata") values: dict[str, Any] = Field(default_factory=dict, description="Current state channel values") interrupts: dict[str, Any] = Field(default_factory=dict, description="Pending interrupts") class ThreadCreateRequest(BaseModel): """Request body for creating a thread.""" thread_id: str | None = Field(default=None, description="Optional thread ID (auto-generated if omitted)") assistant_id: str | None = Field(default=None, description="Associate thread with an assistant") metadata: dict[str, Any] = Field(default_factory=dict, description="Initial metadata") class ThreadSearchRequest(BaseModel): """Request body for searching threads.""" metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata filter (exact match)") limit: int = Field(default=100, ge=1, le=1000, description="Maximum results") offset: int = Field(default=0, ge=0, description="Pagination offset") status: str | None = Field(default=None, description="Filter by thread status") class ThreadStateResponse(BaseModel): """Response model for thread state.""" values: dict[str, Any] = Field(default_factory=dict, description="Current channel values") next: list[str] = Field(default_factory=list, description="Next tasks to execute") metadata: dict[str, Any] = Field(default_factory=dict, description="Checkpoint metadata") checkpoint: dict[str, Any] = Field(default_factory=dict, description="Checkpoint info") checkpoint_id: str | None = Field(default=None, description="Current checkpoint ID") parent_checkpoint_id: str | None = Field(default=None, description="Parent checkpoint ID") created_at: str | None = Field(default=None, description="Checkpoint timestamp") tasks: list[dict[str, Any]] = Field(default_factory=list, description="Interrupted task details") class ThreadPatchRequest(BaseModel): """Request body for patching thread metadata.""" metadata: dict[str, Any] = Field(default_factory=dict, description="Metadata to merge") class ThreadStateUpdateRequest(BaseModel): """Request body for updating thread state (human-in-the-loop resume).""" values: dict[str, Any] | None = Field(default=None, description="Channel values to merge") checkpoint_id: str | None = Field(default=None, description="Checkpoint to branch from") checkpoint: dict[str, Any] | None = Field(default=None, description="Full checkpoint object") as_node: str | None = Field(default=None, description="Node identity for the update") class HistoryEntry(BaseModel): """Single checkpoint history entry.""" checkpoint_id: str parent_checkpoint_id: str | None = None metadata: dict[str, Any] = Field(default_factory=dict) values: dict[str, Any] = Field(default_factory=dict) created_at: str | None = None next: list[str] = Field(default_factory=list) class ThreadHistoryRequest(BaseModel): """Request body for checkpoint history.""" limit: int = Field(default=10, ge=1, le=100, description="Maximum entries") before: str | None = Field(default=None, description="Cursor for pagination") # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDeleteResponse: """Delete local persisted filesystem data for a thread.""" path_manager = paths or get_paths() try: path_manager.delete_thread_dir(thread_id) except ValueError as exc: raise HTTPException(status_code=422, detail=str(exc)) from exc except FileNotFoundError: # Not critical — thread data may not exist on disk logger.debug("No local thread data to delete for %s", _sanitize_log_param(thread_id)) return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}") except Exception as exc: logger.exception("Failed to delete thread data for %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc logger.info("Deleted local thread data for %s", _sanitize_log_param(thread_id)) return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}") async def _store_get(store, thread_id: str) -> dict | None: """Fetch a thread record from the Store; returns ``None`` if absent.""" item = await store.aget(THREADS_NS, thread_id) return item.value if item is not None else None async def _store_put(store, record: dict) -> None: """Write a thread record to the Store.""" await store.aput(THREADS_NS, record["thread_id"], record) async def _store_upsert(store, thread_id: str, *, metadata: dict | None = None, values: dict | None = None) -> None: """Create or refresh a thread record in the Store. On creation the record is written with ``status="idle"``. On update only ``updated_at`` (and optionally ``metadata`` / ``values``) are changed so that existing fields are preserved. ``values`` carries the agent-state snapshot exposed to the frontend (currently just ``{"title": "..."}``). """ now = time.time() existing = await _store_get(store, thread_id) if existing is None: await _store_put( store, { "thread_id": thread_id, "status": "idle", "created_at": now, "updated_at": now, "metadata": metadata or {}, "values": values or {}, }, ) else: val = dict(existing) val["updated_at"] = now if metadata: val.setdefault("metadata", {}).update(metadata) if values: val.setdefault("values", {}).update(values) await _store_put(store, val) def _derive_thread_status(checkpoint_tuple) -> str: """Derive thread status from checkpoint metadata.""" if checkpoint_tuple is None: return "idle" pending_writes = getattr(checkpoint_tuple, "pending_writes", None) or [] # Check for error in pending writes for pw in pending_writes: if len(pw) >= 2 and pw[1] == "__error__": return "error" # Check for pending next tasks (indicates interrupt) tasks = getattr(checkpoint_tuple, "tasks", None) if tasks: return "interrupted" return "idle" # --------------------------------------------------------------------------- # Endpoints # --------------------------------------------------------------------------- @router.delete("/{thread_id}", response_model=ThreadDeleteResponse) async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteResponse: """Delete local persisted filesystem data for a thread. Cleans DeerFlow-managed thread directories, removes checkpoint data, and removes the thread record from the Store. """ # Clean local filesystem response = _delete_thread_data(thread_id) # Remove from Store (best-effort) store = get_store(request) if store is not None: try: await store.adelete(THREADS_NS, thread_id) except Exception: logger.debug("Could not delete store record for thread %s (not critical)", _sanitize_log_param(thread_id)) # Remove checkpoints (best-effort) checkpointer = getattr(request.app.state, "checkpointer", None) if checkpointer is not None: try: if hasattr(checkpointer, "adelete_thread"): await checkpointer.adelete_thread(thread_id) except Exception: logger.debug("Could not delete checkpoints for thread %s (not critical)", _sanitize_log_param(thread_id)) return response @router.post("", response_model=ThreadResponse) async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadResponse: """Create a new thread. The thread record is written to the Store (for fast listing) and an empty checkpoint is written to the checkpointer (for state reads). Idempotent: returns the existing record when ``thread_id`` already exists. """ store = get_store(request) checkpointer = get_checkpointer(request) thread_id = body.thread_id or str(uuid.uuid4()) now = time.time() # Idempotency: return existing record from Store when already present if store is not None: existing_record = await _store_get(store, thread_id) if existing_record is not None: return ThreadResponse( thread_id=thread_id, status=existing_record.get("status", "idle"), created_at=str(existing_record.get("created_at", "")), updated_at=str(existing_record.get("updated_at", "")), metadata=existing_record.get("metadata", {}), ) # Write thread record to Store if store is not None: try: await _store_put( store, { "thread_id": thread_id, "status": "idle", "created_at": now, "updated_at": now, "metadata": body.metadata, }, ) except Exception: logger.exception("Failed to write thread %s to store", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to create thread") # Write an empty checkpoint so state endpoints work immediately config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} try: from langgraph.checkpoint.base import empty_checkpoint ckpt_metadata = { "step": -1, "source": "input", "writes": None, "parents": {}, **body.metadata, "created_at": now, } await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {}) except Exception: logger.exception("Failed to create checkpoint for thread %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to create thread") # Write thread_meta so the thread appears in /threads/search immediately from app.gateway.deps import get_thread_meta_repo thread_meta_repo = get_thread_meta_repo(request) if thread_meta_repo is not None: try: await thread_meta_repo.create( thread_id, assistant_id=getattr(body, "assistant_id", None), metadata=body.metadata, ) except Exception: logger.debug("Failed to upsert thread_meta on create for %s (non-fatal)", _sanitize_log_param(thread_id)) logger.info("Thread created: %s", _sanitize_log_param(thread_id)) return ThreadResponse( thread_id=thread_id, status="idle", created_at=str(now), updated_at=str(now), metadata=body.metadata, ) @router.post("/search", response_model=list[ThreadResponse]) async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]: """Search and list threads from the threads_meta table. NOTE: Migration from pre-persistence-layer deployments: Threads created via LangGraph Server before this change are NOT automatically indexed in threads_meta. They will not appear in search results until a new run is created on them (which triggers thread_meta upsert in services.py). For bulk migration, run: python -m deerflow.persistence.migrate_threads_from_checkpointer (migration script TBD in a follow-up PR) """ from app.gateway.deps import get_thread_meta_repo repo = get_thread_meta_repo(request) if repo is None: raise HTTPException(status_code=503, detail="Thread metadata store not available") rows = await repo.search( metadata=body.metadata or None, status=body.status, limit=body.limit, offset=body.offset, ) return [ ThreadResponse( thread_id=r["thread_id"], status=r.get("status", "idle"), created_at=r.get("created_at", ""), updated_at=r.get("updated_at", ""), metadata=r.get("metadata", {}), values={"title": r["display_name"]} if r.get("display_name") else {}, interrupts={}, ) for r in rows ] @router.patch("/{thread_id}", response_model=ThreadResponse) async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: """Merge metadata into a thread record.""" store = get_store(request) if store is None: raise HTTPException(status_code=503, detail="Store not available") record = await _store_get(store, thread_id) if record is None: raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") now = time.time() updated = dict(record) updated.setdefault("metadata", {}).update(body.metadata) updated["updated_at"] = now try: await _store_put(store, updated) except Exception: logger.exception("Failed to patch thread %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to update thread") return ThreadResponse( thread_id=thread_id, status=updated.get("status", "idle"), created_at=str(updated.get("created_at", "")), updated_at=str(now), metadata=updated.get("metadata", {}), ) @router.get("/{thread_id}", response_model=ThreadResponse) async def get_thread(thread_id: str, request: Request) -> ThreadResponse: """Get thread info. Reads metadata from the Store and derives the accurate execution status from the checkpointer. Falls back to the checkpointer alone for threads that pre-date Store adoption (backward compat). """ store = get_store(request) checkpointer = get_checkpointer(request) record: dict | None = None if store is not None: record = await _store_get(store, thread_id) # Derive accurate status from the checkpointer config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} try: checkpoint_tuple = await checkpointer.aget_tuple(config) except Exception: logger.exception("Failed to get checkpoint for thread %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to get thread") if record is None and checkpoint_tuple is None: raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") # If the thread exists in the checkpointer but not the store (e.g. legacy # data), synthesize a minimal store record from the checkpoint metadata. if record is None and checkpoint_tuple is not None: ckpt_meta = getattr(checkpoint_tuple, "metadata", {}) or {} record = { "thread_id": thread_id, "status": "idle", "created_at": ckpt_meta.get("created_at", ""), "updated_at": ckpt_meta.get("updated_at", ckpt_meta.get("created_at", "")), "metadata": {k: v for k, v in ckpt_meta.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")}, } if record is None: raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") status = _derive_thread_status(checkpoint_tuple) if checkpoint_tuple is not None else record.get("status", "idle") checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} if checkpoint_tuple is not None else {} channel_values = checkpoint.get("channel_values", {}) return ThreadResponse( thread_id=thread_id, status=status, created_at=str(record.get("created_at", "")), updated_at=str(record.get("updated_at", "")), metadata=record.get("metadata", {}), values=serialize_channel_values(channel_values), ) @router.get("/{thread_id}/state", response_model=ThreadStateResponse) async def get_thread_state(thread_id: str, request: Request) -> ThreadStateResponse: """Get the latest state snapshot for a thread. Channel values are serialized to ensure LangChain message objects are converted to JSON-safe dicts. """ checkpointer = get_checkpointer(request) config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} try: checkpoint_tuple = await checkpointer.aget_tuple(config) except Exception: logger.exception("Failed to get state for thread %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to get thread state") if checkpoint_tuple is None: raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} metadata = getattr(checkpoint_tuple, "metadata", {}) or {} checkpoint_id = None ckpt_config = getattr(checkpoint_tuple, "config", {}) if ckpt_config: checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id") channel_values = checkpoint.get("channel_values", {}) parent_config = getattr(checkpoint_tuple, "parent_config", None) parent_checkpoint_id = None if parent_config: parent_checkpoint_id = parent_config.get("configurable", {}).get("checkpoint_id") tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")] tasks = [{"id": getattr(t, "id", ""), "name": getattr(t, "name", "")} for t in tasks_raw] return ThreadStateResponse( values=serialize_channel_values(channel_values), next=next_tasks, metadata=metadata, checkpoint={"id": checkpoint_id, "ts": str(metadata.get("created_at", ""))}, checkpoint_id=checkpoint_id, parent_checkpoint_id=parent_checkpoint_id, created_at=str(metadata.get("created_at", "")), tasks=tasks, ) @router.post("/{thread_id}/state", response_model=ThreadStateResponse) async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, request: Request) -> ThreadStateResponse: """Update thread state (e.g. for human-in-the-loop resume or title rename). Writes a new checkpoint that merges *body.values* into the latest channel values, then syncs any updated ``title`` field back to the Store so that ``/threads/search`` reflects the change immediately. """ checkpointer = get_checkpointer(request) store = get_store(request) # checkpoint_ns must be present in the config for aput — default to "" # (the root graph namespace). checkpoint_id is optional; omitting it # fetches the latest checkpoint for the thread. read_config: dict[str, Any] = { "configurable": { "thread_id": thread_id, "checkpoint_ns": "", } } if body.checkpoint_id: read_config["configurable"]["checkpoint_id"] = body.checkpoint_id try: checkpoint_tuple = await checkpointer.aget_tuple(read_config) except Exception: logger.exception("Failed to get state for thread %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to get thread state") if checkpoint_tuple is None: raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") # Work on mutable copies so we don't accidentally mutate cached objects. checkpoint: dict[str, Any] = dict(getattr(checkpoint_tuple, "checkpoint", {}) or {}) metadata: dict[str, Any] = dict(getattr(checkpoint_tuple, "metadata", {}) or {}) channel_values: dict[str, Any] = dict(checkpoint.get("channel_values", {})) if body.values: channel_values.update(body.values) checkpoint["channel_values"] = channel_values metadata["updated_at"] = time.time() if body.as_node: metadata["source"] = "update" metadata["step"] = metadata.get("step", 0) + 1 metadata["writes"] = {body.as_node: body.values} # aput requires checkpoint_ns in the config — use the same config used for the # read (which always includes checkpoint_ns=""). Do NOT include checkpoint_id # so that aput generates a fresh checkpoint ID for the new snapshot. write_config: dict[str, Any] = { "configurable": { "thread_id": thread_id, "checkpoint_ns": "", } } try: new_config = await checkpointer.aput(write_config, checkpoint, metadata, {}) except Exception: logger.exception("Failed to update state for thread %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to update thread state") new_checkpoint_id: str | None = None if isinstance(new_config, dict): new_checkpoint_id = new_config.get("configurable", {}).get("checkpoint_id") # Sync title changes to the Store so /threads/search reflects them immediately. if store is not None and body.values and "title" in body.values: try: await _store_upsert(store, thread_id, values={"title": body.values["title"]}) except Exception: logger.debug("Failed to sync title to store for thread %s (non-fatal)", _sanitize_log_param(thread_id)) return ThreadStateResponse( values=serialize_channel_values(channel_values), next=[], metadata=metadata, checkpoint_id=new_checkpoint_id, created_at=str(metadata.get("created_at", "")), ) @router.post("/{thread_id}/history", response_model=list[HistoryEntry]) async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request: Request) -> list[HistoryEntry]: """Get checkpoint history for a thread. Combines data from two sources: - **Checkpointer**: checkpoint_id, parent_checkpoint_id, metadata, values.title, values.thread_data - **RunEventStore**: values.messages (complete conversation history, not affected by summarization truncation) """ from app.gateway.deps import get_run_event_store checkpointer = get_checkpointer(request) config: dict[str, Any] = {"configurable": {"thread_id": thread_id}} if body.before: config["configurable"]["checkpoint_id"] = body.before # Fetch messages from event store (full history, not truncated by summarization) event_store = get_run_event_store(request) try: all_messages = await event_store.list_messages(thread_id, limit=10_000) except Exception: logger.warning("Failed to load messages from event store for thread %s", _sanitize_log_param(thread_id), exc_info=True) all_messages = [] entries: list[HistoryEntry] = [] is_latest_checkpoint = True try: async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit): ckpt_config = getattr(checkpoint_tuple, "config", {}) parent_config = getattr(checkpoint_tuple, "parent_config", None) metadata = getattr(checkpoint_tuple, "metadata", {}) or {} checkpoint = getattr(checkpoint_tuple, "checkpoint", {}) or {} checkpoint_id = ckpt_config.get("configurable", {}).get("checkpoint_id", "") parent_id = None if parent_config: parent_id = parent_config.get("configurable", {}).get("checkpoint_id") channel_values = checkpoint.get("channel_values", {}) # Build values: title + thread_data from checkpoint, messages from event store values: dict[str, Any] = {} if title := channel_values.get("title"): values["title"] = title if thread_data := channel_values.get("thread_data"): values["thread_data"] = thread_data # Attach all messages only to the latest (first) checkpoint entry if is_latest_checkpoint and all_messages: values["messages"] = [m.get("content", {}) for m in all_messages] is_latest_checkpoint = False # Derive next tasks tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] next_tasks = [t.name for t in tasks_raw if hasattr(t, "name")] # Strip LangGraph internal keys from metadata user_meta = {k: v for k, v in metadata.items() if k not in ("created_at", "updated_at", "step", "source", "writes", "parents")} # Keep step for ordering context if "step" in metadata: user_meta["step"] = metadata["step"] entries.append( HistoryEntry( checkpoint_id=checkpoint_id, parent_checkpoint_id=parent_id, metadata=user_meta, values=values, created_at=str(metadata.get("created_at", "")), next=next_tasks, ) ) except Exception: logger.exception("Failed to get history for thread %s", _sanitize_log_param(thread_id)) raise HTTPException(status_code=500, detail="Failed to get thread history") return entries