From 35f141fc48ff0ae70ebfb97d8c8ccd9565187b52 Mon Sep 17 00:00:00 2001 From: luo jiyin Date: Thu, 9 Apr 2026 17:56:36 +0800 Subject: [PATCH] feat: implement full checkpoint rollback on user cancellation (#1867) * feat: implement full checkpoint rollback on user cancellation - Capture pre-run checkpoint snapshot including checkpoint state, metadata, and pending_writes - Add _rollback_to_pre_run_checkpoint() function to restore thread state - Implement _call_checkpointer_method() helper to support both async and sync checkpointer methods - Rollback now properly restores checkpoint, metadata, channel_versions, and pending_writes - Remove obsolete TODO comment (Phase 2) as rollback is now complete This resolves the TODO(Phase 2) comment and enables full thread state restoration when a run is cancelled by the user. * fix: address rollback review feedback * fix: strengthen checkpoint rollback validation and error handling - Validate restored_config structure and checkpoint_id before use - Raise RuntimeError on malformed pending_writes instead of silent skip - Normalize None checkpoint_ns to empty string instead of "None" - Move delete_thread to only execute when pre_run_snapshot is None - Add docstring noting non-atomic rollback as known limitation This addresses review feedback on PR #1867 regarding data integrity in the checkpoint rollback implementation. * test: add comprehensive coverage for checkpoint rollback edge cases - test_rollback_restores_snapshot_without_deleting_thread - test_rollback_deletes_thread_when_no_snapshot_exists - test_rollback_raises_when_restore_config_has_no_checkpoint_id - test_rollback_normalizes_none_checkpoint_ns_to_root_namespace - test_rollback_raises_on_malformed_pending_write_not_a_tuple - test_rollback_raises_on_malformed_pending_write_non_string_channel - test_rollback_propagates_aput_writes_failure Covers all scenarios from PR #1867 review feedback. * test: format rollback worker tests --- .../harness/deerflow/runtime/runs/worker.py | 161 +++++++++++-- backend/tests/test_run_worker_rollback.py | 214 ++++++++++++++++++ 2 files changed, 356 insertions(+), 19 deletions(-) create mode 100644 backend/tests/test_run_worker_rollback.py diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 2d67ecb27..c8b074f7a 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -16,6 +16,8 @@ internal checkpoint callbacks that are not exposed in the Python public API. from __future__ import annotations import asyncio +import copy +import inspect import logging from typing import Any, Literal @@ -51,6 +53,9 @@ async def run_agent( run_id = record.run_id thread_id = record.thread_id requested_modes: set[str] = set(stream_modes or ["values"]) + pre_run_checkpoint_id: str | None = None + pre_run_snapshot: dict[str, Any] | None = None + snapshot_capture_failed = False # Track whether "events" was requested but skipped if "events" in requested_modes: @@ -63,15 +68,23 @@ async def run_agent( # 1. Mark running await run_manager.set_status(run_id, RunStatus.running) - # Record pre-run checkpoint_id to support rollback (Phase 2). - pre_run_checkpoint_id = None - try: - config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} - ckpt_tuple = await checkpointer.aget_tuple(config_for_check) - if ckpt_tuple is not None: - pre_run_checkpoint_id = getattr(ckpt_tuple, "config", {}).get("configurable", {}).get("checkpoint_id") - except Exception: - logger.debug("Could not get pre-run checkpoint_id for run %s", run_id) + # Snapshot the latest pre-run checkpoint so rollback can restore it. + if checkpointer is not None: + try: + config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} + ckpt_tuple = await checkpointer.aget_tuple(config_for_check) + if ckpt_tuple is not None: + ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {}) + pre_run_checkpoint_id = ckpt_config.get("checkpoint_id") + pre_run_snapshot = { + "checkpoint_ns": ckpt_config.get("checkpoint_ns", ""), + "checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})), + "metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})), + "pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []), + } + except Exception: + snapshot_capture_failed = True + logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True) # 2. Publish metadata — useStream needs both run_id AND thread_id await bridge.publish( @@ -172,17 +185,18 @@ async def run_agent( action = record.abort_action if action == "rollback": await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user") - # TODO(Phase 2): Implement full checkpoint rollback. - # Use pre_run_checkpoint_id to revert the thread's checkpoint - # to the state before this run started. Requires a - # checkpointer.adelete() or equivalent API. try: - if checkpointer is not None and pre_run_checkpoint_id is not None: - # Phase 2: roll back to pre_run_checkpoint_id - pass - logger.info("Run %s rolled back", run_id) + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id=thread_id, + run_id=run_id, + pre_run_checkpoint_id=pre_run_checkpoint_id, + pre_run_snapshot=pre_run_snapshot, + snapshot_capture_failed=snapshot_capture_failed, + ) + logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id) except Exception: - logger.warning("Failed to rollback checkpoint for run %s", run_id) + logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True) else: await run_manager.set_status(run_id, RunStatus.interrupted) else: @@ -192,7 +206,18 @@ async def run_agent( action = record.abort_action if action == "rollback": await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user") - logger.info("Run %s was cancelled (rollback)", run_id) + try: + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id=thread_id, + run_id=run_id, + pre_run_checkpoint_id=pre_run_checkpoint_id, + pre_run_snapshot=pre_run_snapshot, + snapshot_capture_failed=snapshot_capture_failed, + ) + logger.info("Run %s was cancelled and rolled back", run_id) + except Exception: + logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True) else: await run_manager.set_status(run_id, RunStatus.interrupted) logger.info("Run %s was cancelled", run_id) @@ -220,6 +245,104 @@ async def run_agent( # --------------------------------------------------------------------------- +async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any: + """Call a checkpointer method, supporting async and sync variants.""" + method = getattr(checkpointer, async_name, None) or getattr(checkpointer, sync_name, None) + if method is None: + raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}") + result = method(*args, **kwargs) + if inspect.isawaitable(result): + return await result + return result + + +async def _rollback_to_pre_run_checkpoint( + *, + checkpointer: Any, + thread_id: str, + run_id: str, + pre_run_checkpoint_id: str | None, + pre_run_snapshot: dict[str, Any] | None, + snapshot_capture_failed: bool, +) -> None: + """Restore thread state to the checkpoint snapshot captured before run start.""" + if checkpointer is None: + logger.info("Run %s rollback requested but no checkpointer is configured", run_id) + return + + if snapshot_capture_failed: + logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id) + return + + if pre_run_snapshot is None: + await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id) + logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id) + return + + checkpoint_to_restore = None + metadata_to_restore: dict[str, Any] = {} + checkpoint_ns = "" + checkpoint = pre_run_snapshot.get("checkpoint") + if not isinstance(checkpoint, dict): + logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id) + return + checkpoint_to_restore = checkpoint + if checkpoint_to_restore.get("id") is None and pre_run_checkpoint_id is not None: + checkpoint_to_restore = {**checkpoint_to_restore, "id": pre_run_checkpoint_id} + if checkpoint_to_restore.get("id") is None: + logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id) + return + metadata = pre_run_snapshot.get("metadata", {}) + metadata_to_restore = metadata if isinstance(metadata, dict) else {} + raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns") + checkpoint_ns = raw_checkpoint_ns if isinstance(raw_checkpoint_ns, str) else "" + + channel_versions = checkpoint_to_restore.get("channel_versions") + new_versions = dict(channel_versions) if isinstance(channel_versions, dict) else {} + + restore_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}} + restored_config = await _call_checkpointer_method( + checkpointer, + "aput", + "put", + restore_config, + checkpoint_to_restore, + metadata_to_restore if isinstance(metadata_to_restore, dict) else {}, + new_versions, + ) + if not isinstance(restored_config, dict): + raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict") + restored_configurable = restored_config.get("configurable", {}) + if not isinstance(restored_configurable, dict): + raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload") + restored_checkpoint_id = restored_configurable.get("checkpoint_id") + if not restored_checkpoint_id: + raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id") + + pending_writes = pre_run_snapshot.get("pending_writes", []) + if not pending_writes: + return + + writes_by_task: dict[str, list[tuple[str, Any]]] = {} + for item in pending_writes: + if not isinstance(item, (tuple, list)) or len(item) != 3: + raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {item!r}") + task_id, channel, value = item + if not isinstance(channel, str): + raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}") + writes_by_task.setdefault(str(task_id), []).append((channel, value)) + + for task_id, writes in writes_by_task.items(): + await _call_checkpointer_method( + checkpointer, + "aput_writes", + "put_writes", + restored_config, + writes, + task_id=task_id, + ) + + def _lg_mode_to_sse_event(mode: str) -> str: """Map LangGraph internal stream_mode name to SSE event name. diff --git a/backend/tests/test_run_worker_rollback.py b/backend/tests/test_run_worker_rollback.py new file mode 100644 index 000000000..714ccdde1 --- /dev/null +++ b/backend/tests/test_run_worker_rollback.py @@ -0,0 +1,214 @@ +from unittest.mock import AsyncMock, call + +import pytest + +from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint + + +class FakeCheckpointer: + def __init__(self, *, put_result): + self.adelete_thread = AsyncMock() + self.aput = AsyncMock(return_value=put_result) + self.aput_writes = AsyncMock() + + +@pytest.mark.anyio +async def test_rollback_restores_snapshot_without_deleting_thread(): + checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) + + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="ckpt-1", + pre_run_snapshot={ + "checkpoint_ns": "", + "checkpoint": { + "id": "ckpt-1", + "channel_versions": {"messages": 3}, + "channel_values": {"messages": ["before"]}, + }, + "metadata": {"source": "input"}, + "pending_writes": [ + ("task-a", "messages", {"content": "first"}), + ("task-a", "status", "done"), + ("task-b", "events", {"type": "tool"}), + ], + }, + snapshot_capture_failed=False, + ) + + checkpointer.adelete_thread.assert_not_awaited() + checkpointer.aput.assert_awaited_once_with( + {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, + { + "id": "ckpt-1", + "channel_versions": {"messages": 3}, + "channel_values": {"messages": ["before"]}, + }, + {"source": "input"}, + {"messages": 3}, + ) + assert checkpointer.aput_writes.await_args_list == [ + call( + {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}, + [("messages", {"content": "first"}), ("status", "done")], + task_id="task-a", + ), + call( + {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}, + [("events", {"type": "tool"})], + task_id="task-b", + ), + ] + + +@pytest.mark.anyio +async def test_rollback_deletes_thread_when_no_snapshot_exists(): + checkpointer = FakeCheckpointer(put_result=None) + + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id=None, + pre_run_snapshot=None, + snapshot_capture_failed=False, + ) + + checkpointer.adelete_thread.assert_awaited_once_with("thread-1") + checkpointer.aput.assert_not_awaited() + checkpointer.aput_writes.assert_not_awaited() + + +@pytest.mark.anyio +async def test_rollback_raises_when_restore_config_has_no_checkpoint_id(): + checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}) + + with pytest.raises(RuntimeError, match="did not return checkpoint_id"): + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="ckpt-1", + pre_run_snapshot={ + "checkpoint_ns": "", + "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, + "metadata": {}, + "pending_writes": [("task-a", "messages", "value")], + }, + snapshot_capture_failed=False, + ) + + checkpointer.adelete_thread.assert_not_awaited() + checkpointer.aput.assert_awaited_once() + checkpointer.aput_writes.assert_not_awaited() + + +@pytest.mark.anyio +async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace(): + checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) + + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="ckpt-1", + pre_run_snapshot={ + "checkpoint_ns": None, + "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, + "metadata": {}, + "pending_writes": [], + }, + snapshot_capture_failed=False, + ) + + checkpointer.aput.assert_awaited_once_with( + {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}, + {"id": "ckpt-1", "channel_versions": {}}, + {}, + {}, + ) + + +@pytest.mark.anyio +async def test_rollback_raises_on_malformed_pending_write_not_a_tuple(): + """pending_writes containing a non-3-tuple item should raise RuntimeError.""" + checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) + + with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"): + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="ckpt-1", + pre_run_snapshot={ + "checkpoint_ns": "", + "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, + "metadata": {}, + "pending_writes": [ + ("task-a", "messages", "valid"), # valid + ["only", "two"], # malformed: only 2 elements + ], + }, + snapshot_capture_failed=False, + ) + + # aput succeeded but aput_writes should not be called due to malformed data + checkpointer.aput.assert_awaited_once() + checkpointer.aput_writes.assert_not_awaited() + + +@pytest.mark.anyio +async def test_rollback_raises_on_malformed_pending_write_non_string_channel(): + """pending_writes containing a non-string channel should raise RuntimeError.""" + checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) + + with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"): + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="ckpt-1", + pre_run_snapshot={ + "checkpoint_ns": "", + "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, + "metadata": {}, + "pending_writes": [ + ("task-a", 123, "value"), # malformed: channel is not a string + ], + }, + snapshot_capture_failed=False, + ) + + checkpointer.aput.assert_awaited_once() + checkpointer.aput_writes.assert_not_awaited() + + +@pytest.mark.anyio +async def test_rollback_propagates_aput_writes_failure(): + """If aput_writes fails, the exception should propagate (not be swallowed).""" + checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}}) + # Simulate aput_writes failure + checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost") + + with pytest.raises(RuntimeError, match="Database connection lost"): + await _rollback_to_pre_run_checkpoint( + checkpointer=checkpointer, + thread_id="thread-1", + run_id="run-1", + pre_run_checkpoint_id="ckpt-1", + pre_run_snapshot={ + "checkpoint_ns": "", + "checkpoint": {"id": "ckpt-1", "channel_versions": {}}, + "metadata": {}, + "pending_writes": [ + ("task-a", "messages", "value"), + ], + }, + snapshot_capture_failed=False, + ) + + # aput succeeded, aput_writes was called but failed + checkpointer.aput.assert_awaited_once() + checkpointer.aput_writes.assert_awaited_once()