mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-25 11:18:22 +00:00
feat: implement full checkpoint rollback on user cancellation (#1867)
* feat: implement full checkpoint rollback on user cancellation - Capture pre-run checkpoint snapshot including checkpoint state, metadata, and pending_writes - Add _rollback_to_pre_run_checkpoint() function to restore thread state - Implement _call_checkpointer_method() helper to support both async and sync checkpointer methods - Rollback now properly restores checkpoint, metadata, channel_versions, and pending_writes - Remove obsolete TODO comment (Phase 2) as rollback is now complete This resolves the TODO(Phase 2) comment and enables full thread state restoration when a run is cancelled by the user. * fix: address rollback review feedback * fix: strengthen checkpoint rollback validation and error handling - Validate restored_config structure and checkpoint_id before use - Raise RuntimeError on malformed pending_writes instead of silent skip - Normalize None checkpoint_ns to empty string instead of "None" - Move delete_thread to only execute when pre_run_snapshot is None - Add docstring noting non-atomic rollback as known limitation This addresses review feedback on PR #1867 regarding data integrity in the checkpoint rollback implementation. * test: add comprehensive coverage for checkpoint rollback edge cases - test_rollback_restores_snapshot_without_deleting_thread - test_rollback_deletes_thread_when_no_snapshot_exists - test_rollback_raises_when_restore_config_has_no_checkpoint_id - test_rollback_normalizes_none_checkpoint_ns_to_root_namespace - test_rollback_raises_on_malformed_pending_write_not_a_tuple - test_rollback_raises_on_malformed_pending_write_non_string_channel - test_rollback_propagates_aput_writes_failure Covers all scenarios from PR #1867 review feedback. * test: format rollback worker tests
This commit is contained in:
parent
0b6fa8b9e1
commit
35f141fc48
@ -16,6 +16,8 @@ internal checkpoint callbacks that are not exposed in the Python public API.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
@ -51,6 +53,9 @@ async def run_agent(
|
||||
run_id = record.run_id
|
||||
thread_id = record.thread_id
|
||||
requested_modes: set[str] = set(stream_modes or ["values"])
|
||||
pre_run_checkpoint_id: str | None = None
|
||||
pre_run_snapshot: dict[str, Any] | None = None
|
||||
snapshot_capture_failed = False
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
@ -63,15 +68,23 @@ async def run_agent(
|
||||
# 1. Mark running
|
||||
await run_manager.set_status(run_id, RunStatus.running)
|
||||
|
||||
# Record pre-run checkpoint_id to support rollback (Phase 2).
|
||||
pre_run_checkpoint_id = None
|
||||
try:
|
||||
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
|
||||
if ckpt_tuple is not None:
|
||||
pre_run_checkpoint_id = getattr(ckpt_tuple, "config", {}).get("configurable", {}).get("checkpoint_id")
|
||||
except Exception:
|
||||
logger.debug("Could not get pre-run checkpoint_id for run %s", run_id)
|
||||
# Snapshot the latest pre-run checkpoint so rollback can restore it.
|
||||
if checkpointer is not None:
|
||||
try:
|
||||
config_for_check = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
ckpt_tuple = await checkpointer.aget_tuple(config_for_check)
|
||||
if ckpt_tuple is not None:
|
||||
ckpt_config = getattr(ckpt_tuple, "config", {}).get("configurable", {})
|
||||
pre_run_checkpoint_id = ckpt_config.get("checkpoint_id")
|
||||
pre_run_snapshot = {
|
||||
"checkpoint_ns": ckpt_config.get("checkpoint_ns", ""),
|
||||
"checkpoint": copy.deepcopy(getattr(ckpt_tuple, "checkpoint", {})),
|
||||
"metadata": copy.deepcopy(getattr(ckpt_tuple, "metadata", {})),
|
||||
"pending_writes": copy.deepcopy(getattr(ckpt_tuple, "pending_writes", []) or []),
|
||||
}
|
||||
except Exception:
|
||||
snapshot_capture_failed = True
|
||||
logger.warning("Could not capture pre-run checkpoint snapshot for run %s", run_id, exc_info=True)
|
||||
|
||||
# 2. Publish metadata — useStream needs both run_id AND thread_id
|
||||
await bridge.publish(
|
||||
@ -172,17 +185,18 @@ async def run_agent(
|
||||
action = record.abort_action
|
||||
if action == "rollback":
|
||||
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
|
||||
# TODO(Phase 2): Implement full checkpoint rollback.
|
||||
# Use pre_run_checkpoint_id to revert the thread's checkpoint
|
||||
# to the state before this run started. Requires a
|
||||
# checkpointer.adelete() or equivalent API.
|
||||
try:
|
||||
if checkpointer is not None and pre_run_checkpoint_id is not None:
|
||||
# Phase 2: roll back to pre_run_checkpoint_id
|
||||
pass
|
||||
logger.info("Run %s rolled back", run_id)
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
pre_run_checkpoint_id=pre_run_checkpoint_id,
|
||||
pre_run_snapshot=pre_run_snapshot,
|
||||
snapshot_capture_failed=snapshot_capture_failed,
|
||||
)
|
||||
logger.info("Run %s rolled back to pre-run checkpoint %s", run_id, pre_run_checkpoint_id)
|
||||
except Exception:
|
||||
logger.warning("Failed to rollback checkpoint for run %s", run_id)
|
||||
logger.warning("Failed to rollback checkpoint for run %s", run_id, exc_info=True)
|
||||
else:
|
||||
await run_manager.set_status(run_id, RunStatus.interrupted)
|
||||
else:
|
||||
@ -192,7 +206,18 @@ async def run_agent(
|
||||
action = record.abort_action
|
||||
if action == "rollback":
|
||||
await run_manager.set_status(run_id, RunStatus.error, error="Rolled back by user")
|
||||
logger.info("Run %s was cancelled (rollback)", run_id)
|
||||
try:
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
pre_run_checkpoint_id=pre_run_checkpoint_id,
|
||||
pre_run_snapshot=pre_run_snapshot,
|
||||
snapshot_capture_failed=snapshot_capture_failed,
|
||||
)
|
||||
logger.info("Run %s was cancelled and rolled back", run_id)
|
||||
except Exception:
|
||||
logger.warning("Run %s cancellation rollback failed", run_id, exc_info=True)
|
||||
else:
|
||||
await run_manager.set_status(run_id, RunStatus.interrupted)
|
||||
logger.info("Run %s was cancelled", run_id)
|
||||
@ -220,6 +245,104 @@ async def run_agent(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _call_checkpointer_method(checkpointer: Any, async_name: str, sync_name: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Call a checkpointer method, supporting async and sync variants."""
|
||||
method = getattr(checkpointer, async_name, None) or getattr(checkpointer, sync_name, None)
|
||||
if method is None:
|
||||
raise AttributeError(f"Missing checkpointer method: {async_name}/{sync_name}")
|
||||
result = method(*args, **kwargs)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
async def _rollback_to_pre_run_checkpoint(
|
||||
*,
|
||||
checkpointer: Any,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
pre_run_checkpoint_id: str | None,
|
||||
pre_run_snapshot: dict[str, Any] | None,
|
||||
snapshot_capture_failed: bool,
|
||||
) -> None:
|
||||
"""Restore thread state to the checkpoint snapshot captured before run start."""
|
||||
if checkpointer is None:
|
||||
logger.info("Run %s rollback requested but no checkpointer is configured", run_id)
|
||||
return
|
||||
|
||||
if snapshot_capture_failed:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint snapshot capture failed", run_id)
|
||||
return
|
||||
|
||||
if pre_run_snapshot is None:
|
||||
await _call_checkpointer_method(checkpointer, "adelete_thread", "delete_thread", thread_id)
|
||||
logger.info("Run %s rollback reset thread %s to empty state", run_id, thread_id)
|
||||
return
|
||||
|
||||
checkpoint_to_restore = None
|
||||
metadata_to_restore: dict[str, Any] = {}
|
||||
checkpoint_ns = ""
|
||||
checkpoint = pre_run_snapshot.get("checkpoint")
|
||||
if not isinstance(checkpoint, dict):
|
||||
logger.warning("Run %s rollback skipped: invalid pre-run checkpoint snapshot", run_id)
|
||||
return
|
||||
checkpoint_to_restore = checkpoint
|
||||
if checkpoint_to_restore.get("id") is None and pre_run_checkpoint_id is not None:
|
||||
checkpoint_to_restore = {**checkpoint_to_restore, "id": pre_run_checkpoint_id}
|
||||
if checkpoint_to_restore.get("id") is None:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||
return
|
||||
metadata = pre_run_snapshot.get("metadata", {})
|
||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||
checkpoint_ns = raw_checkpoint_ns if isinstance(raw_checkpoint_ns, str) else ""
|
||||
|
||||
channel_versions = checkpoint_to_restore.get("channel_versions")
|
||||
new_versions = dict(channel_versions) if isinstance(channel_versions, dict) else {}
|
||||
|
||||
restore_config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": checkpoint_ns}}
|
||||
restored_config = await _call_checkpointer_method(
|
||||
checkpointer,
|
||||
"aput",
|
||||
"put",
|
||||
restore_config,
|
||||
checkpoint_to_restore,
|
||||
metadata_to_restore if isinstance(metadata_to_restore, dict) else {},
|
||||
new_versions,
|
||||
)
|
||||
if not isinstance(restored_config, dict):
|
||||
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config: expected dict")
|
||||
restored_configurable = restored_config.get("configurable", {})
|
||||
if not isinstance(restored_configurable, dict):
|
||||
raise RuntimeError(f"Run {run_id} rollback restore returned invalid config payload")
|
||||
restored_checkpoint_id = restored_configurable.get("checkpoint_id")
|
||||
if not restored_checkpoint_id:
|
||||
raise RuntimeError(f"Run {run_id} rollback restore did not return checkpoint_id")
|
||||
|
||||
pending_writes = pre_run_snapshot.get("pending_writes", [])
|
||||
if not pending_writes:
|
||||
return
|
||||
|
||||
writes_by_task: dict[str, list[tuple[str, Any]]] = {}
|
||||
for item in pending_writes:
|
||||
if not isinstance(item, (tuple, list)) or len(item) != 3:
|
||||
raise RuntimeError(f"Run {run_id} rollback failed: pending_write is not a 3-tuple: {item!r}")
|
||||
task_id, channel, value = item
|
||||
if not isinstance(channel, str):
|
||||
raise RuntimeError(f"Run {run_id} rollback failed: pending_write has non-string channel: task_id={task_id!r}, channel={channel!r}")
|
||||
writes_by_task.setdefault(str(task_id), []).append((channel, value))
|
||||
|
||||
for task_id, writes in writes_by_task.items():
|
||||
await _call_checkpointer_method(
|
||||
checkpointer,
|
||||
"aput_writes",
|
||||
"put_writes",
|
||||
restored_config,
|
||||
writes,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
|
||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||
|
||||
|
||||
214
backend/tests/test_run_worker_rollback.py
Normal file
214
backend/tests/test_run_worker_rollback.py
Normal file
@ -0,0 +1,214 @@
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.runs.worker import _rollback_to_pre_run_checkpoint
|
||||
|
||||
|
||||
class FakeCheckpointer:
|
||||
def __init__(self, *, put_result):
|
||||
self.adelete_thread = AsyncMock()
|
||||
self.aput = AsyncMock(return_value=put_result)
|
||||
self.aput_writes = AsyncMock()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
"metadata": {"source": "input"},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", {"content": "first"}),
|
||||
("task-a", "status", "done"),
|
||||
("task-b", "events", {"type": "tool"}),
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
{"source": "input"},
|
||||
{"messages": 3},
|
||||
)
|
||||
assert checkpointer.aput_writes.await_args_list == [
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
[("messages", {"content": "first"}), ("status", "done")],
|
||||
task_id="task-a",
|
||||
),
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
[("events", {"type": "tool"})],
|
||||
task_id="task-b",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||
checkpointer = FakeCheckpointer(put_result=None)
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id=None,
|
||||
pre_run_snapshot=None,
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_awaited_once_with("thread-1")
|
||||
checkpointer.aput.assert_not_awaited()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [("task-a", "messages", "value")],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": None,
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{"id": "ckpt-1", "channel_versions": {}},
|
||||
{},
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
|
||||
"""pending_writes containing a non-3-tuple item should raise RuntimeError."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", "valid"), # valid
|
||||
["only", "two"], # malformed: only 2 elements
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
# aput succeeded but aput_writes should not be called due to malformed data
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
|
||||
"""pending_writes containing a non-string channel should raise RuntimeError."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
|
||||
with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", 123, "value"), # malformed: channel is not a string
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_propagates_aput_writes_failure():
|
||||
"""If aput_writes fails, the exception should propagate (not be swallowed)."""
|
||||
checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
|
||||
# Simulate aput_writes failure
|
||||
checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Database connection lost"):
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="ckpt-1",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": {"id": "ckpt-1", "channel_versions": {}},
|
||||
"metadata": {},
|
||||
"pending_writes": [
|
||||
("task-a", "messages", "value"),
|
||||
],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
# aput succeeded, aput_writes was called but failed
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
checkpointer.aput_writes.assert_awaited_once()
|
||||
Loading…
x
Reference in New Issue
Block a user