mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-02 23:08:22 +00:00
fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582)
* Restore rollback checkpoints with fresh ids
* Tighten rollback checkpoint tests and imports
* Update test_run_worker_rollback.py

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
parent
866d1ca409
commit
17447fccbe
@ -18,6 +18,7 @@ import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
@ -262,8 +263,6 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
# Write an empty checkpoint so state endpoints work immediately
|
||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
try:
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
ckpt_metadata = {
|
||||
"step": -1,
|
||||
"source": "input",
|
||||
|
||||
@ -23,6 +23,8 @@ from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
if checkpoint_to_restore.get("id") is None:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||
return
|
||||
restore_marker = _new_checkpoint_marker()
|
||||
checkpoint_to_restore = {
|
||||
**checkpoint_to_restore,
|
||||
"id": restore_marker["id"],
|
||||
"ts": restore_marker["ts"],
|
||||
}
|
||||
metadata = pre_run_snapshot.get("metadata", {})
|
||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||
@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
)
|
||||
|
||||
|
||||
def _new_checkpoint_marker() -> dict[str, str]:
    """Return the ``id`` and ``ts`` of a freshly created empty checkpoint.

    Only those two fields are copied out of the value produced by
    ``empty_checkpoint()``; the rest of that checkpoint is discarded.
    """
    fresh = empty_checkpoint()
    return {key: fresh[key] for key in ("id", "ts")}
|
||||
|
||||
|
||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||
|
||||
|
||||
@ -47,4 +47,3 @@ members = ["packages/harness"]
|
||||
|
||||
[tool.uv.sources]
|
||||
deerflow-harness = { workspace = true }
|
||||
|
||||
|
||||
@ -3,6 +3,8 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.runtime.runs.manager import RunManager
|
||||
from deerflow.runtime.runs.schemas import RunStatus
|
||||
@ -16,6 +18,14 @@ class FakeCheckpointer:
|
||||
self.aput_writes = AsyncMock()
|
||||
|
||||
|
||||
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
    """Build an empty checkpoint with the given id and a single ``messages`` channel.

    Args:
        checkpoint_id: Value stored under the checkpoint's ``id`` key.
        messages: Stored as ``channel_values["messages"]``.
        version: Stored as ``channel_versions["messages"]``.

    Returns:
        The checkpoint produced by ``empty_checkpoint()`` with those three
        keys overwritten.
    """
    ckpt = empty_checkpoint()
    ckpt.update(
        {
            "id": checkpoint_id,
            "channel_values": {"messages": messages},
            "channel_versions": {"messages": version},
        }
    )
    return ckpt
|
||||
|
||||
|
||||
def test_build_runtime_context_includes_app_config_when_present():
|
||||
app_config = object()
|
||||
|
||||
@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
{"source": "input"},
|
||||
{"messages": 3},
|
||||
)
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert "channel_versions" in restored_checkpoint
|
||||
assert "channel_values" in restored_checkpoint
|
||||
assert restored_checkpoint["channel_versions"] == {"messages": 3}
|
||||
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
|
||||
assert restored_metadata == {"source": "input"}
|
||||
assert new_versions == {"messages": 3}
|
||||
assert checkpointer.aput_writes.await_args_list == [
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
    """After rollback on a real ``InMemorySaver``, the latest checkpoint holds the pre-run state."""
    saver = InMemorySaver()
    root_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}

    # Seed the thread: a pre-run checkpoint, then a newer one that carries a pending write.
    before_checkpoint = _make_checkpoint("0001", ["before"], 1)
    before_config = saver.put(root_config, before_checkpoint, {"step": 1}, {"messages": 1})
    after_checkpoint = _make_checkpoint("0002", ["after"], 2)
    after_config = saver.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
    saver.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")

    snapshot = {
        "checkpoint_ns": "",
        "checkpoint": before_checkpoint,
        "metadata": {"step": 1},
        "pending_writes": [("task-before", "messages", "pending-before")],
    }
    await _rollback_to_pre_run_checkpoint(
        checkpointer=saver,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="0001",
        pre_run_snapshot=snapshot,
        snapshot_capture_failed=False,
    )

    latest = saver.get_tuple(root_config)

    assert latest is not None
    # The restored checkpoint must be stored under an id distinct from both seeded ones.
    latest_id = latest.config["configurable"]["checkpoint_id"]
    assert latest_id != "0001"
    assert latest_id != "0002"
    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
    # Only the snapshot's pending writes are expected on the latest checkpoint.
    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
    assert ("task-after", "messages", "pending-after") not in latest.pending_writes
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||
checkpointer = FakeCheckpointer(put_result=None)
|
||||
@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{"id": "ckpt-1", "channel_versions": {}},
|
||||
{},
|
||||
{},
|
||||
)
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert restored_checkpoint["channel_versions"] == {}
|
||||
assert restored_metadata == {}
|
||||
assert new_versions == {}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user