refactor(provider): share assistant payload replay matching (#3307)

* Share assistant payload replay matching

* fix(provider): recover assistant field when ordinal AI index is taken

The mismatch-length fallback in `_match_ai_message` only tried the exact
`fallback_ordinal` AI index. When serialization drops or reorders an
assistant message, a unique signature match can consume a non-ordinal
index, leaving a later ambiguous payload's ordinal already used — so its
provider field (e.g. `reasoning_content`) was silently dropped.

Scan forward from the ordinal for the next unused `AIMessage` (wrapping to
earlier indices) to preserve the positional bias while still recovering
the field. Forward scanning avoids a naive min-unused pick that could
restore the wrong field after a leading message is dropped.

Add a regression test for the dropped-leading-message case.

* fix(provider): avoid earlier assistant fallback replay
This commit is contained in:
AochenShen99 2026-05-29 23:05:59 +08:00 committed by GitHub
parent 052b1e2102
commit 4093c83383
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 307 additions and 50 deletions

View File

@ -0,0 +1,124 @@
"""Helpers for replaying provider-specific assistant message fields.
Several provider adapters need to preserve fields that LangChain stores on the
original ``AIMessage`` but drops when serializing request payloads. This module
keeps the assistant-message matching logic shared while letting each provider
decide which fields to restore.
"""
from __future__ import annotations
import json
from collections.abc import Callable, Sequence
from typing import Any
from langchain_core.messages import AIMessage, BaseMessage
# Signature of the callback handed to ``restore_assistant_payloads``: it is
# called with a serialized assistant payload dict and the original ``AIMessage``
# paired with it. It returns ``None``, so any effect has to come from mutating
# the payload dict in place.
AssistantPayloadRestorer = Callable[[dict[str, Any], AIMessage], None]
def restore_assistant_payloads(
    payload_messages: Sequence[dict[str, Any]],
    original_messages: Sequence[BaseMessage],
    restore: AssistantPayloadRestorer,
) -> None:
    """Re-attach provider-specific fields to serialized assistant messages.

    When both sequences have the same length they are walked in lockstep and
    ``restore`` is invoked for every position that holds an assistant payload
    next to an ``AIMessage``. Otherwise each assistant payload is paired with
    an original ``AIMessage`` through ``_match_ai_message`` and ``restore`` is
    invoked for every pairing that was found.

    Args:
        payload_messages: Serialized request messages (dicts with a ``role``).
        original_messages: The messages the payload was produced from.
        restore: Callback receiving ``(payload_msg, ai_message)`` for each pair.
    """
    same_length = len(payload_messages) == len(original_messages)
    if not same_length:
        # Positions no longer line up, so only assistant entries are compared
        # against the AIMessages, using signature/ordinal matching.
        originals = [candidate for candidate in original_messages if isinstance(candidate, AIMessage)]
        serialized = [entry for entry in payload_messages if entry.get("role") == "assistant"]
        consumed: set[int] = set()
        for position, entry in enumerate(serialized):
            matched = _match_ai_message(entry, originals, consumed, position)
            if matched is None:
                continue
            restore(entry, matched)
        return

    for entry, source in zip(payload_messages, original_messages):
        is_assistant = entry.get("role") == "assistant"
        if is_assistant and isinstance(source, AIMessage):
            restore(entry, source)
def restore_additional_kwargs_field(payload_msg: dict[str, Any], orig_msg: AIMessage, field_name: str) -> None:
    """Mirror one ``additional_kwargs`` entry of ``orig_msg`` into ``payload_msg``.

    The payload is left untouched when the entry is missing or ``None``; any
    other value (including an empty string) is stored under ``field_name``.
    """
    extras = orig_msg.additional_kwargs
    if (stored := extras.get(field_name)) is None:
        return
    payload_msg[field_name] = stored
def restore_reasoning_content(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
    """Carry the original message's ``reasoning_content`` over to its serialized payload.

    Delegates to ``restore_additional_kwargs_field`` with the field name fixed,
    which gives it the two-argument shape of a ``restore`` callback.
    """
    restore_additional_kwargs_field(payload_msg, orig_msg, field_name="reasoning_content")
def _match_ai_message(
    payload_msg: dict[str, Any],
    ai_messages: Sequence[AIMessage],
    used_ai_indexes: set[int],
    fallback_ordinal: int,
) -> AIMessage | None:
    """Pick the original ``AIMessage`` that corresponds to ``payload_msg``.

    A signature match wins when exactly one not-yet-used message shares the
    payload's signature. In every other case (no signature, no match, or
    several matches) the first unused message at or after ``fallback_ordinal``
    is taken. The chosen index is recorded in ``used_ai_indexes``.

    Returns:
        The matched message, or ``None`` when no unused candidate remains.
    """
    signature = _assistant_signature(payload_msg)
    if signature is not None:
        candidates: list[int] = []
        for position, candidate in enumerate(ai_messages):
            if position in used_ai_indexes:
                continue
            if _ai_signature(candidate) == signature:
                candidates.append(position)
        if len(candidates) == 1:
            (chosen,) = candidates
            used_ai_indexes.add(chosen)
            return ai_messages[chosen]

    # Positional fallback: ambiguous or signature-less payloads follow order.
    chosen = _next_unused_index_at_or_after(len(ai_messages), used_ai_indexes, fallback_ordinal)
    if chosen is None:
        return None
    used_ai_indexes.add(chosen)
    return ai_messages[chosen]
def _next_unused_index_at_or_after(count: int, used_ai_indexes: set[int], start: int) -> int | None:
"""Return the next unused AI index at or after ``start``.
Scanning forward from the payload's ordinal preserves the positional bias of
the previous behaviour while still recovering when serialization drops or
reorders messages so the exact ordinal index is already taken. It does not
wrap to earlier indexes because those messages may be represented by payload
entries that were already dropped.
"""
if count == 0 or start >= count:
return None
for index in range(start, count):
if index not in used_ai_indexes:
return index
return None
def _assistant_signature(payload_msg: dict[str, Any]) -> tuple[str, str] | None:
    """Compute the matching signature of a serialized assistant payload.

    The signature is derived from the payload's ``content`` and the ids found
    in its ``tool_calls`` entry (a missing or empty entry counts as no calls).
    """
    body = payload_msg.get("content")
    raw_tool_calls = payload_msg.get("tool_calls") or []
    return _signature(body, _tool_call_ids(raw_tool_calls))
def _ai_signature(message: AIMessage) -> tuple[str, str] | None:
    """Compute the matching signature of an original ``AIMessage``.

    Tool calls are read from ``message.tool_calls`` first; when that is empty
    the ``tool_calls`` entry of ``additional_kwargs`` is used instead.
    """
    raw_tool_calls = message.tool_calls
    if not raw_tool_calls:
        raw_tool_calls = message.additional_kwargs.get("tool_calls") or []
    return _signature(message.content, _tool_call_ids(raw_tool_calls))
def _signature(content: Any, tool_call_ids: tuple[str, ...]) -> tuple[str, str] | None:
if content in (None, "") and not tool_call_ids:
return None
return (_stable_repr(content), "|".join(tool_call_ids))
def _stable_repr(value: Any) -> str:
try:
return json.dumps(value, sort_keys=True, ensure_ascii=False)
except TypeError:
return repr(value)
def _tool_call_ids(tool_calls: Sequence[Any]) -> tuple[str, ...]:
ids: list[str] = []
for tool_call in tool_calls:
if isinstance(tool_call, dict):
call_id = tool_call.get("id")
if isinstance(call_id, str) and call_id:
ids.append(call_id)
return tuple(ids)

View File

@ -10,9 +10,10 @@ on all assistant messages when thinking mode is enabled.
from typing import Any
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_deepseek import ChatDeepSeek
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
class PatchedChatDeepSeek(ChatDeepSeek):
"""ChatDeepSeek with proper reasoning_content preservation.
@ -49,25 +50,10 @@ class PatchedChatDeepSeek(ChatDeepSeek):
# Call parent to get the base payload
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
# Match payload messages with original messages to restore reasoning_content
payload_messages = payload.get("messages", [])
# The payload messages and original messages should be in the same order
# Iterate through both and match by position
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning_content is not None:
payload_msg["reasoning_content"] = reasoning_content
else:
# Fallback: match by counting assistant messages
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
if reasoning_content is not None:
payload_messages[idx]["reasoning_content"] = reasoning_content
restore_assistant_payloads(
payload.get("messages", []),
original_messages,
restore_reasoning_content,
)
return payload

View File

@ -17,6 +17,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_openai import ChatOpenAI
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
_MISSING = object()
@ -45,12 +47,6 @@ def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str)
return message.model_copy(update={"additional_kwargs": additional_kwargs})
def _restore_reasoning_content(payload_msg: dict, orig_msg: AIMessage) -> None:
reasoning = orig_msg.additional_kwargs.get("reasoning_content")
if reasoning is not None:
payload_msg["reasoning_content"] = reasoning
def _get_typed_choice_message(response: Any, index: int) -> Any:
choices = getattr(response, "choices", None)
if choices is None:
@ -81,17 +77,11 @@ class PatchedChatMiMo(ChatOpenAI):
) -> dict:
original_messages = self._convert_input(input_).to_messages()
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
payload_messages = payload.get("messages", [])
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
_restore_reasoning_content(payload_msg, orig_msg)
else:
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
_restore_reasoning_content(payload_msg, ai_msg)
restore_assistant_payloads(
payload.get("messages", []),
original_messages,
restore_reasoning_content,
)
return payload

View File

@ -27,6 +27,8 @@ from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import AIMessage
from langchain_openai import ChatOpenAI
from deerflow.models.assistant_payload_replay import restore_assistant_payloads
class PatchedChatOpenAI(ChatOpenAI):
"""ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
@ -75,18 +77,7 @@ class PatchedChatOpenAI(ChatOpenAI):
# Obtain the base payload from the parent implementation.
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
payload_messages = payload.get("messages", [])
if len(payload_messages) == len(original_messages):
for payload_msg, orig_msg in zip(payload_messages, original_messages):
if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
_restore_tool_call_signatures(payload_msg, orig_msg)
else:
# Fallback: match assistant-role entries positionally against AIMessages.
ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
_restore_tool_call_signatures(payload_msg, ai_msg)
restore_assistant_payloads(payload.get("messages", []), original_messages, _restore_tool_call_signatures)
return payload

View File

@ -0,0 +1,166 @@
"""Tests for shared assistant payload replay helpers."""
from __future__ import annotations
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.models.assistant_payload_replay import (
restore_additional_kwargs_field,
restore_assistant_payloads,
restore_reasoning_content,
)
def _restore_reasoning(payload_msg: dict, orig_msg: AIMessage) -> None:
    """Restore callback for these tests: copies ``reasoning_content`` onto the payload."""
    restore_additional_kwargs_field(payload_msg, orig_msg, field_name="reasoning_content")
def test_restore_additional_kwargs_field_copies_present_values_only():
    """An empty string is copied; ``None`` values and absent keys are skipped."""
    target = {"role": "assistant"}
    source = AIMessage(
        content="answer",
        additional_kwargs={"reasoning_content": "", "ignored_none": None},
    )

    for field in ("reasoning_content", "ignored_none", "missing"):
        restore_additional_kwargs_field(target, source, field)

    assert target == {"role": "assistant", "reasoning_content": ""}
def test_restore_reasoning_content_copies_reasoning_content():
    """``restore_reasoning_content`` writes the original reasoning onto the payload."""
    target = {"role": "assistant"}
    source = AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"})

    restore_reasoning_content(target, source)

    assert target["reasoning_content"] == "thought"
def test_restore_assistant_payloads_matches_by_position_when_lengths_match():
    """With equal lengths, the assistant payload is paired with the message at the same position."""
    originals = [
        HumanMessage(content="question"),
        AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"}),
    ]
    user_payload = {"role": "user", "content": "question"}
    assistant_payload = {"role": "assistant", "content": "answer"}
    payloads = [user_payload, assistant_payload]

    restore_assistant_payloads(payloads, originals, _restore_reasoning)

    assert payloads[1]["reasoning_content"] == "thought"
def test_restore_assistant_payloads_fallback_matches_unique_content_signature():
    """With mismatched lengths, a payload is matched to the only message sharing its content."""
    first = AIMessage(content="first", additional_kwargs={"reasoning_content": "first-thought"})
    second = AIMessage(content="second", additional_kwargs={"reasoning_content": "second-thought"})
    payloads = [{"role": "assistant", "content": "second"}]

    restore_assistant_payloads(payloads, [first, second], _restore_reasoning)

    assert payloads[0]["reasoning_content"] == "second-thought"
def test_restore_assistant_payloads_fallback_matches_unique_tool_call_signature():
    """With empty content, the tool-call id alone identifies the original message."""
    first = AIMessage(
        content="",
        additional_kwargs={"reasoning_content": "first-thought"},
        tool_calls=[{"id": "call_first", "name": "tool", "args": {}}],
    )
    second = AIMessage(
        content="",
        additional_kwargs={"reasoning_content": "second-thought"},
        tool_calls=[{"id": "call_second", "name": "tool", "args": {}}],
    )
    serialized_call = {"id": "call_second", "type": "function", "function": {"name": "tool", "arguments": "{}"}}
    payloads = [{"role": "assistant", "content": "", "tool_calls": [serialized_call]}]

    restore_assistant_payloads(payloads, [first, second], _restore_reasoning)

    assert payloads[0]["reasoning_content"] == "second-thought"
def test_restore_assistant_payloads_fallback_matches_structured_content_signature():
    """Structured content matches even when the payload's dict keys are in a different order."""
    originals = [
        AIMessage(
            content=[{"type": "text", "text": label}],
            additional_kwargs={"reasoning_content": thought},
        )
        for label, thought in (("first", "first-thought"), ("second", "second-thought"))
    ]
    # Same block as the second original, but with its keys swapped.
    payloads = [{"role": "assistant", "content": [{"text": "second", "type": "text"}]}]

    restore_assistant_payloads(payloads, originals, _restore_reasoning)

    assert payloads[0]["reasoning_content"] == "second-thought"
def test_restore_assistant_payloads_fallback_uses_order_when_signature_is_ambiguous():
    """Without a usable signature, the payload falls back to the message at its ordinal."""
    first = AIMessage(content="", additional_kwargs={"reasoning_content": "first-thought"})
    second = AIMessage(content="", additional_kwargs={"reasoning_content": "second-thought"})
    payloads = [{"role": "assistant", "content": ""}]

    restore_assistant_payloads(payloads, [first, second], _restore_reasoning)

    assert payloads[0]["reasoning_content"] == "first-thought"
def test_restore_assistant_payloads_fallback_uses_next_unused_when_ordinal_taken():
    """A taken ordinal index is skipped in favour of the next unused message after it.

    The leading empty assistant message is missing from the payload, so payload
    ordinals are shifted against the original indexes. The first payload claims
    index 1 through its unique signature, which is also the ordinal of the
    second (signature-less) payload. That payload must still receive a field
    from the remaining unused message rather than silently getting nothing.
    """
    dropped = AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-thought"})
    unique = AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"})
    trailing = AIMessage(content="", additional_kwargs={"reasoning_content": "trailing-thought"})
    unique_payload = {"role": "assistant", "content": "unique"}
    empty_payload = {"role": "assistant", "content": ""}
    payloads = [unique_payload, empty_payload]

    restore_assistant_payloads(payloads, [dropped, unique, trailing], _restore_reasoning)

    assert payloads[0]["reasoning_content"] == "unique-thought"
    # Scanning forward from the taken ordinal lands on the trailing message; a
    # plain "lowest unused index" pick would wrongly choose the dropped one.
    assert payloads[1]["reasoning_content"] == "trailing-thought"
def test_restore_assistant_payloads_does_not_wrap_to_earlier_unused_message():
    """When nothing unused remains at or after the ordinal, no earlier message is reused."""
    originals = [
        HumanMessage(content="leading user"),
        AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-leading-thought"}),
        AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
    ]
    unique_payload = {"role": "assistant", "content": "unique"}
    empty_payload = {"role": "assistant", "content": ""}
    payloads = [unique_payload, empty_payload]

    restore_assistant_payloads(payloads, originals, _restore_reasoning)

    assert payloads[0]["reasoning_content"] == "unique-thought"
    # The only unused AIMessage sits before the ordinal, so nothing is restored.
    assert "reasoning_content" not in payloads[1]