mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-30 04:18:09 +00:00
refactor(provider): share assistant payload replay matching (#3307)
* Share assistant payload replay matching * fix(provider): recover assistant field when ordinal AI index is taken The mismatch-length fallback in `_match_ai_message` only tried the exact `fallback_ordinal` AI index. When serialization drops or reorders an assistant message, a unique signature match can consume a non-ordinal index, leaving a later ambiguous payload's ordinal already used — so its provider field (e.g. `reasoning_content`) was silently dropped. Scan forward from the ordinal for the next unused `AIMessage` (wrapping to earlier indices) to preserve the positional bias while still recovering the field. Forward scanning avoids a naive min-unused pick that could restore the wrong field after a leading message is dropped. Add a regression test for the dropped-leading-message case. * fix(provider): avoid earlier assistant fallback replay
This commit is contained in:
parent
052b1e2102
commit
4093c83383
@ -0,0 +1,124 @@
|
||||
"""Helpers for replaying provider-specific assistant message fields.
|
||||
|
||||
Several provider adapters need to preserve fields that LangChain stores on the
|
||||
original ``AIMessage`` but drops when serializing request payloads. This module
|
||||
keeps the assistant-message matching logic shared while letting each provider
|
||||
decide which fields to restore.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
|
||||
AssistantPayloadRestorer = Callable[[dict[str, Any], AIMessage], None]
|
||||
|
||||
|
||||
def restore_assistant_payloads(
|
||||
payload_messages: Sequence[dict[str, Any]],
|
||||
original_messages: Sequence[BaseMessage],
|
||||
restore: AssistantPayloadRestorer,
|
||||
) -> None:
|
||||
"""Restore provider-specific fields onto serialized assistant payloads."""
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
restore(payload_msg, orig_msg)
|
||||
return
|
||||
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
|
||||
used_ai_indexes: set[int] = set()
|
||||
|
||||
for ordinal, payload_msg in enumerate(assistant_payloads):
|
||||
ai_msg = _match_ai_message(payload_msg, ai_messages, used_ai_indexes, ordinal)
|
||||
if ai_msg is not None:
|
||||
restore(payload_msg, ai_msg)
|
||||
|
||||
|
||||
def restore_additional_kwargs_field(payload_msg: dict[str, Any], orig_msg: AIMessage, field_name: str) -> None:
|
||||
"""Copy a provider-specific ``additional_kwargs`` field onto a payload message."""
|
||||
value = orig_msg.additional_kwargs.get(field_name)
|
||||
if value is not None:
|
||||
payload_msg[field_name] = value
|
||||
|
||||
|
||||
def restore_reasoning_content(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
|
||||
"""Copy provider reasoning content onto a serialized assistant payload."""
|
||||
restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
|
||||
|
||||
|
||||
def _match_ai_message(
|
||||
payload_msg: dict[str, Any],
|
||||
ai_messages: Sequence[AIMessage],
|
||||
used_ai_indexes: set[int],
|
||||
fallback_ordinal: int,
|
||||
) -> AIMessage | None:
|
||||
payload_key = _assistant_signature(payload_msg)
|
||||
if payload_key is not None:
|
||||
matches = [index for index, ai_msg in enumerate(ai_messages) if index not in used_ai_indexes and _ai_signature(ai_msg) == payload_key]
|
||||
if len(matches) == 1:
|
||||
used_ai_indexes.add(matches[0])
|
||||
return ai_messages[matches[0]]
|
||||
|
||||
fallback_index = _next_unused_index_at_or_after(len(ai_messages), used_ai_indexes, fallback_ordinal)
|
||||
if fallback_index is not None:
|
||||
used_ai_indexes.add(fallback_index)
|
||||
return ai_messages[fallback_index]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _next_unused_index_at_or_after(count: int, used_ai_indexes: set[int], start: int) -> int | None:
|
||||
"""Return the next unused AI index at or after ``start``.
|
||||
|
||||
Scanning forward from the payload's ordinal preserves the positional bias of
|
||||
the previous behaviour while still recovering when serialization drops or
|
||||
reorders messages so the exact ordinal index is already taken. It does not
|
||||
wrap to earlier indexes because those messages may be represented by payload
|
||||
entries that were already dropped.
|
||||
"""
|
||||
if count == 0 or start >= count:
|
||||
return None
|
||||
for index in range(start, count):
|
||||
if index not in used_ai_indexes:
|
||||
return index
|
||||
return None
|
||||
|
||||
|
||||
def _assistant_signature(payload_msg: dict[str, Any]) -> tuple[str, str] | None:
|
||||
return _signature(
|
||||
payload_msg.get("content"),
|
||||
_tool_call_ids(payload_msg.get("tool_calls") or []),
|
||||
)
|
||||
|
||||
|
||||
def _ai_signature(message: AIMessage) -> tuple[str, str] | None:
|
||||
tool_calls = message.tool_calls or message.additional_kwargs.get("tool_calls") or []
|
||||
return _signature(message.content, _tool_call_ids(tool_calls))
|
||||
|
||||
|
||||
def _signature(content: Any, tool_call_ids: tuple[str, ...]) -> tuple[str, str] | None:
|
||||
if content in (None, "") and not tool_call_ids:
|
||||
return None
|
||||
return (_stable_repr(content), "|".join(tool_call_ids))
|
||||
|
||||
|
||||
def _stable_repr(value: Any) -> str:
|
||||
try:
|
||||
return json.dumps(value, sort_keys=True, ensure_ascii=False)
|
||||
except TypeError:
|
||||
return repr(value)
|
||||
|
||||
|
||||
def _tool_call_ids(tool_calls: Sequence[Any]) -> tuple[str, ...]:
|
||||
ids: list[str] = []
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
call_id = tool_call.get("id")
|
||||
if isinstance(call_id, str) and call_id:
|
||||
ids.append(call_id)
|
||||
return tuple(ids)
|
||||
@ -10,9 +10,10 @@ on all assistant messages when thinking mode is enabled.
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
|
||||
|
||||
|
||||
class PatchedChatDeepSeek(ChatDeepSeek):
|
||||
"""ChatDeepSeek with proper reasoning_content preservation.
|
||||
@ -49,25 +50,10 @@ class PatchedChatDeepSeek(ChatDeepSeek):
|
||||
# Call parent to get the base payload
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
|
||||
# Match payload messages with original messages to restore reasoning_content
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
# The payload messages and original messages should be in the same order
|
||||
# Iterate through both and match by position
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning_content is not None:
|
||||
payload_msg["reasoning_content"] = reasoning_content
|
||||
else:
|
||||
# Fallback: match by counting assistant messages
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
|
||||
for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning_content is not None:
|
||||
payload_messages[idx]["reasoning_content"] = reasoning_content
|
||||
restore_assistant_payloads(
|
||||
payload.get("messages", []),
|
||||
original_messages,
|
||||
restore_reasoning_content,
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
@ -17,6 +17,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
|
||||
|
||||
_MISSING = object()
|
||||
|
||||
|
||||
@ -45,12 +47,6 @@ def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str)
|
||||
return message.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
|
||||
|
||||
def _restore_reasoning_content(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
|
||||
if reasoning is not None:
|
||||
payload_msg["reasoning_content"] = reasoning
|
||||
|
||||
|
||||
def _get_typed_choice_message(response: Any, index: int) -> Any:
|
||||
choices = getattr(response, "choices", None)
|
||||
if choices is None:
|
||||
@ -81,17 +77,11 @@ class PatchedChatMiMo(ChatOpenAI):
|
||||
) -> dict:
|
||||
original_messages = self._convert_input(input_).to_messages()
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_reasoning_content(payload_msg, orig_msg)
|
||||
else:
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
|
||||
for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_reasoning_content(payload_msg, ai_msg)
|
||||
restore_assistant_payloads(
|
||||
payload.get("messages", []),
|
||||
original_messages,
|
||||
restore_reasoning_content,
|
||||
)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
@ -27,6 +27,8 @@ from langchain_core.language_models import LanguageModelInput
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from deerflow.models.assistant_payload_replay import restore_assistant_payloads
|
||||
|
||||
|
||||
class PatchedChatOpenAI(ChatOpenAI):
|
||||
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
|
||||
@ -75,18 +77,7 @@ class PatchedChatOpenAI(ChatOpenAI):
|
||||
# Obtain the base payload from the parent implementation.
|
||||
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
||||
|
||||
payload_messages = payload.get("messages", [])
|
||||
|
||||
if len(payload_messages) == len(original_messages):
|
||||
for payload_msg, orig_msg in zip(payload_messages, original_messages):
|
||||
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
|
||||
_restore_tool_call_signatures(payload_msg, orig_msg)
|
||||
else:
|
||||
# Fallback: match assistant-role entries positionally against AIMessages.
|
||||
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
|
||||
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
|
||||
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
|
||||
_restore_tool_call_signatures(payload_msg, ai_msg)
|
||||
restore_assistant_payloads(payload.get("messages", []), original_messages, _restore_tool_call_signatures)
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
166
backend/tests/test_assistant_payload_replay.py
Normal file
166
backend/tests/test_assistant_payload_replay.py
Normal file
@ -0,0 +1,166 @@
|
||||
"""Tests for shared assistant payload replay helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.models.assistant_payload_replay import (
|
||||
restore_additional_kwargs_field,
|
||||
restore_assistant_payloads,
|
||||
restore_reasoning_content,
|
||||
)
|
||||
|
||||
|
||||
def _restore_reasoning(payload_msg: dict, orig_msg: AIMessage) -> None:
|
||||
restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
|
||||
|
||||
|
||||
def test_restore_additional_kwargs_field_copies_present_values_only():
|
||||
payload_message = {"role": "assistant"}
|
||||
orig_message = AIMessage(
|
||||
content="answer",
|
||||
additional_kwargs={
|
||||
"reasoning_content": "",
|
||||
"ignored_none": None,
|
||||
},
|
||||
)
|
||||
|
||||
restore_additional_kwargs_field(payload_message, orig_message, "reasoning_content")
|
||||
restore_additional_kwargs_field(payload_message, orig_message, "ignored_none")
|
||||
restore_additional_kwargs_field(payload_message, orig_message, "missing")
|
||||
|
||||
assert payload_message == {"role": "assistant", "reasoning_content": ""}
|
||||
|
||||
|
||||
def test_restore_reasoning_content_copies_reasoning_content():
|
||||
payload_message = {"role": "assistant"}
|
||||
orig_message = AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"})
|
||||
|
||||
restore_reasoning_content(payload_message, orig_message)
|
||||
|
||||
assert payload_message["reasoning_content"] == "thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_matches_by_position_when_lengths_match():
|
||||
original_messages = [
|
||||
HumanMessage(content="question"),
|
||||
AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"}),
|
||||
]
|
||||
payload_messages = [
|
||||
{"role": "user", "content": "question"},
|
||||
{"role": "assistant", "content": "answer"},
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[1]["reasoning_content"] == "thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_matches_unique_content_signature():
|
||||
original_messages = [
|
||||
AIMessage(content="first", additional_kwargs={"reasoning_content": "first-thought"}),
|
||||
AIMessage(content="second", additional_kwargs={"reasoning_content": "second-thought"}),
|
||||
]
|
||||
payload_messages = [{"role": "assistant", "content": "second"}]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "second-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_matches_unique_tool_call_signature():
|
||||
original_messages = [
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"reasoning_content": "first-thought"},
|
||||
tool_calls=[{"id": "call_first", "name": "tool", "args": {}}],
|
||||
),
|
||||
AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"reasoning_content": "second-thought"},
|
||||
tool_calls=[{"id": "call_second", "name": "tool", "args": {}}],
|
||||
),
|
||||
]
|
||||
payload_messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [{"id": "call_second", "type": "function", "function": {"name": "tool", "arguments": "{}"}}],
|
||||
}
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "second-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_matches_structured_content_signature():
|
||||
original_messages = [
|
||||
AIMessage(
|
||||
content=[{"type": "text", "text": "first"}],
|
||||
additional_kwargs={"reasoning_content": "first-thought"},
|
||||
),
|
||||
AIMessage(
|
||||
content=[{"type": "text", "text": "second"}],
|
||||
additional_kwargs={"reasoning_content": "second-thought"},
|
||||
),
|
||||
]
|
||||
payload_messages = [{"role": "assistant", "content": [{"text": "second", "type": "text"}]}]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "second-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_uses_order_when_signature_is_ambiguous():
|
||||
original_messages = [
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "first-thought"}),
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "second-thought"}),
|
||||
]
|
||||
payload_messages = [{"role": "assistant", "content": ""}]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "first-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_fallback_uses_next_unused_when_ordinal_taken():
|
||||
# Serialization dropped a leading empty assistant message, so payload ordinals
|
||||
# no longer line up with the original AIMessage indices. The first payload
|
||||
# uniquely matches a non-ordinal index by signature, which leaves the later
|
||||
# ambiguous payload's exact ordinal index already used. It must still fall
|
||||
# back to the remaining unused AIMessage (scanning forward from the ordinal)
|
||||
# instead of silently dropping the field.
|
||||
original_messages = [
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-thought"}),
|
||||
AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "trailing-thought"}),
|
||||
]
|
||||
payload_messages = [
|
||||
{"role": "assistant", "content": "unique"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "unique-thought"
|
||||
# Forward scan from the taken ordinal picks the trailing message, not the
|
||||
# dropped leading one (which a naive min-unused scan would wrongly select).
|
||||
assert payload_messages[1]["reasoning_content"] == "trailing-thought"
|
||||
|
||||
|
||||
def test_restore_assistant_payloads_does_not_wrap_to_earlier_unused_message():
|
||||
original_messages = [
|
||||
HumanMessage(content="leading user"),
|
||||
AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-leading-thought"}),
|
||||
AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
|
||||
]
|
||||
payload_messages = [
|
||||
{"role": "assistant", "content": "unique"},
|
||||
{"role": "assistant", "content": ""},
|
||||
]
|
||||
|
||||
restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
|
||||
|
||||
assert payload_messages[0]["reasoning_content"] == "unique-thought"
|
||||
assert "reasoning_content" not in payload_messages[1]
|
||||
Loading…
x
Reference in New Issue
Block a user