From 17447fccbe91aa685f5363d85ee2b5c0afa323ce Mon Sep 17 00:00:00 2001 From: KiteEater <145987840+Kiteeater@users.noreply.github.com> Date: Sat, 2 May 2026 11:25:45 +0800 Subject: [PATCH] fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582) * Restore rollback checkpoints with fresh ids * Tighten rollback checkpoint tests and imports * Update test_run_worker_rollback.py --------- Co-authored-by: Willem Jiang --- backend/app/gateway/routers/threads.py | 3 +- .../harness/deerflow/runtime/runs/worker.py | 13 ++++ backend/pyproject.toml | 1 - backend/tests/test_run_worker_rollback.py | 77 +++++++++++++++---- 4 files changed, 75 insertions(+), 19 deletions(-) diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index 484582839..253717d11 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -18,6 +18,7 @@ import uuid from typing import Any from fastapi import APIRouter, HTTPException, Request +from langgraph.checkpoint.base import empty_checkpoint from pydantic import BaseModel, Field, field_validator from app.gateway.authz import require_permission @@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe # Write an empty checkpoint so state endpoints work immediately config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} try: - from langgraph.checkpoint.base import empty_checkpoint - ckpt_metadata = { "step": -1, "source": "input", diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index d8f9c139b..2aecb9a1b 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -23,6 +23,8 @@ from dataclasses import dataclass, field from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, cast +from langgraph.checkpoint.base import empty_checkpoint + if TYPE_CHECKING: from langchain_core.messages import HumanMessage @@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint( if checkpoint_to_restore.get("id") is None: logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id) return + restore_marker = _new_checkpoint_marker() + checkpoint_to_restore = { + **checkpoint_to_restore, + "id": restore_marker["id"], + "ts": restore_marker["ts"], + } metadata = pre_run_snapshot.get("metadata", {}) metadata_to_restore = metadata if isinstance(metadata, dict) else {} raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns") @@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint( ) +def _new_checkpoint_marker() -> dict[str, str]: + marker = empty_checkpoint() + return {"id": marker["id"], "ts": marker["ts"]} + + def _lg_mode_to_sse_event(mode: str) -> str: """Map LangGraph internal stream_mode name to SSE event name. diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 64c6e74c3..1b74a77c4 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -47,4 +47,3 @@ members = ["packages/harness"] [tool.uv.sources] deerflow-harness = { workspace = true } - diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py index 0c99663ad..0a4421e2f 100644 --- a/backend/tests/test_run_worker_rollback.py +++ b/backend/tests/test_run_worker_rollback.py @@ -3,6 +3,8 @@ from types import SimpleNamespace from unittest.mock import AsyncMock, call import pytest +from langgraph.checkpoint.base import empty_checkpoint +from langgraph.checkpoint.memory import InMemorySaver from deerflow.runtime.runs.manager import RunManager from deerflow.runtime.runs.schemas import RunStatus @@ -16,6 +18,14 @@ class FakeCheckpointer: self.aput_writes = AsyncMock() +def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int): + checkpoint = empty_checkpoint() + checkpoint["id"] = checkpoint_id + checkpoint["channel_values"] = {"messages": messages} + checkpoint["channel_versions"] = {"messages": version} + return checkpoint + + def test_build_runtime_context_includes_app_config_when_present(): app_config = object() @@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread(): ) checkpointer.adelete_thread.assert_not_awaited() - checkpointer.aput.assert_awaited_once_with( - {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, - { - "id": "ckpt-1", - "channel_versions": {"messages": 3}, - "channel_values": {"messages": ["before"]}, - }, - {"source": "input"}, - {"messages": 3}, - ) + checkpointer.aput.assert_awaited_once() + restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args + assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + assert restored_checkpoint["id"] != "ckpt-1" + assert "channel_versions" in restored_checkpoint + assert "channel_values" in restored_checkpoint + assert restored_checkpoint["channel_versions"] == {"messages": 3} + assert restored_checkpoint["channel_values"] == {"messages": ["before"]} + assert restored_metadata == {"source": "input"} + assert new_versions == {"messages": 3} assert checkpointer.aput_writes.await_args_list == [ call( {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}, @@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread(): ] +@pytest.mark.anyio +async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer(): + checkpointer = InMemorySaver() + thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + before_checkpoint = _make_checkpoint("0001", ["before"], 1) + before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1}) + after_checkpoint = _make_checkpoint("0002", ["after"], 2) + after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2}) + checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after") + + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="0001", + pre_run_snapshot={ + "checkpoint_ns": "", + "checkpoint": before_checkpoint, + "metadata": {"step": 1}, + "pending_writes": [("task-before", "messages", "pending-before")], + }, + snapshot_capture_failed=False, + ) + + latest = checkpointer.get_tuple(thread_config) + + assert latest is not None + assert latest.config["configurable"]["checkpoint_id"] != "0001" + assert latest.config["configurable"]["checkpoint_id"] != "0002" + assert latest.checkpoint["channel_values"] == {"messages": ["before"]} + assert latest.pending_writes == [("task-before", "messages", "pending-before")] + assert ("task-after", "messages", "pending-after") not in latest.pending_writes + + @pytest.mark.anyio async def test_rollback_deletes_thread_when_no_snapshot_exists(): checkpointer = FakeCheckpointer(put_result=None) @@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace(): snapshot_capture_failed=False, ) - checkpointer.aput.assert_awaited_once_with( - {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, - {"id": "ckpt-1", "channel_versions": {}}, - {}, - {}, - ) + checkpointer.aput.assert_awaited_once() + restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args + assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + assert restored_checkpoint["id"] != "ckpt-1" + assert restored_checkpoint["channel_versions"] == {} + assert restored_metadata == {} + assert new_versions == {} @pytest.mark.anyio