mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-30 20:38:09 +00:00
* Share assistant payload replay matching * fix(provider): recover assistant field when ordinal AI index is taken The mismatch-length fallback in `_match_ai_message` only tried the exact `fallback_ordinal` AI index. When serialization drops or reorders an assistant message, a unique signature match can consume a non-ordinal index, leaving a later ambiguous payload's ordinal already used — so its provider field (e.g. `reasoning_content`) was silently dropped. Scan forward from the ordinal for the next unused `AIMessage` (wrapping to earlier indices) to preserve the positional bias while still recovering the field. Forward scanning avoids a naive min-unused pick that could restore the wrong field after a leading message is dropped. Add a regression test for the dropped-leading-message case. * fix(provider): avoid earlier assistant fallback replay
141 lines
5.4 KiB
Python
141 lines
5.4 KiB
Python
"""Patched ChatOpenAI adapter for Xiaomi MiMo reasoning_content replay.
|
|
|
|
MiMo's OpenAI-compatible API returns ``reasoning_content`` in thinking mode and
|
|
requires that value to be replayed on historical assistant messages in
|
|
multi-turn agent conversations. Standard ``langchain_openai.ChatOpenAI`` drops
|
|
that provider-specific field, which can cause HTTP 400 errors once tool calls
|
|
enter the conversation history.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Mapping
|
|
from typing import Any
|
|
|
|
from langchain_core.language_models import LanguageModelInput
|
|
from langchain_core.messages import AIMessage, AIMessageChunk
|
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
|
from langchain_openai import ChatOpenAI
|
|
|
|
from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
|
|
|
|
_MISSING = object()
|
|
|
|
|
|
def _extract_reasoning_content(value: Any) -> str | object:
|
|
"""Return reasoning_content from a dict/Pydantic object, preserving empty strings."""
|
|
if isinstance(value, Mapping):
|
|
if "reasoning_content" in value and value["reasoning_content"] is not None:
|
|
return value["reasoning_content"]
|
|
return _MISSING
|
|
|
|
reasoning = getattr(value, "reasoning_content", _MISSING)
|
|
if reasoning is not _MISSING and reasoning is not None:
|
|
return reasoning
|
|
|
|
model_extra = getattr(value, "model_extra", None)
|
|
if isinstance(model_extra, Mapping) and "reasoning_content" in model_extra and model_extra["reasoning_content"] is not None:
|
|
return model_extra["reasoning_content"]
|
|
|
|
return _MISSING
|
|
|
|
|
|
def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str) -> AIMessage | AIMessageChunk:
|
|
additional_kwargs = dict(message.additional_kwargs)
|
|
if additional_kwargs.get("reasoning_content") != reasoning:
|
|
additional_kwargs["reasoning_content"] = reasoning
|
|
return message.model_copy(update={"additional_kwargs": additional_kwargs})
|
|
|
|
|
|
def _get_typed_choice_message(response: Any, index: int) -> Any:
|
|
choices = getattr(response, "choices", None)
|
|
if choices is None:
|
|
return None
|
|
try:
|
|
return choices[index].message
|
|
except (AttributeError, IndexError, TypeError):
|
|
return None
|
|
|
|
|
|
class PatchedChatMiMo(ChatOpenAI):
|
|
"""ChatOpenAI with ``reasoning_content`` preservation for MiMo thinking mode."""
|
|
|
|
@classmethod
|
|
def is_lc_serializable(cls) -> bool:
|
|
return True
|
|
|
|
@property
|
|
def lc_secrets(self) -> dict[str, str]:
|
|
return {"api_key": "MIMO_API_KEY", "openai_api_key": "MIMO_API_KEY"}
|
|
|
|
def _get_request_payload(
|
|
self,
|
|
input_: LanguageModelInput,
|
|
*,
|
|
stop: list[str] | None = None,
|
|
**kwargs: Any,
|
|
) -> dict:
|
|
original_messages = self._convert_input(input_).to_messages()
|
|
payload = super()._get_request_payload(input_, stop=stop, **kwargs)
|
|
restore_assistant_payloads(
|
|
payload.get("messages", []),
|
|
original_messages,
|
|
restore_reasoning_content,
|
|
)
|
|
|
|
return payload
|
|
|
|
def _convert_chunk_to_generation_chunk(
|
|
self,
|
|
chunk: dict,
|
|
default_chunk_class: type,
|
|
base_generation_info: dict | None,
|
|
) -> ChatGenerationChunk | None:
|
|
generation_chunk = super()._convert_chunk_to_generation_chunk(
|
|
chunk,
|
|
default_chunk_class,
|
|
base_generation_info,
|
|
)
|
|
if generation_chunk is None:
|
|
return None
|
|
|
|
choices = chunk.get("choices", [])
|
|
if choices:
|
|
delta = choices[0].get("delta") or {}
|
|
reasoning = _extract_reasoning_content(delta)
|
|
if reasoning is not _MISSING and isinstance(generation_chunk.message, AIMessageChunk):
|
|
generation_chunk = ChatGenerationChunk(
|
|
message=_with_reasoning_content(generation_chunk.message, reasoning),
|
|
generation_info=generation_chunk.generation_info,
|
|
)
|
|
|
|
return generation_chunk
|
|
|
|
def _create_chat_result(
|
|
self,
|
|
response: dict | Any,
|
|
generation_info: dict | None = None,
|
|
) -> ChatResult:
|
|
result = super()._create_chat_result(response, generation_info)
|
|
response_dict = response if isinstance(response, dict) else response.model_dump()
|
|
choices = response_dict.get("choices", [])
|
|
|
|
patched_generations: list[ChatGeneration] | None = None
|
|
for index, generation in enumerate(result.generations):
|
|
choice = choices[index] if index < len(choices) else {}
|
|
choice_message = choice.get("message", {}) if isinstance(choice, Mapping) else {}
|
|
reasoning = _extract_reasoning_content(choice_message)
|
|
if reasoning is _MISSING and not isinstance(response, dict):
|
|
reasoning = _extract_reasoning_content(_get_typed_choice_message(response, index))
|
|
|
|
message = generation.message
|
|
if reasoning is not _MISSING and isinstance(message, AIMessage):
|
|
if patched_generations is None:
|
|
patched_generations = list(result.generations)
|
|
patched_generations[index] = ChatGeneration(
|
|
message=_with_reasoning_content(message, reasoning),
|
|
generation_info=generation.generation_info,
|
|
)
|
|
|
|
return ChatResult(generations=patched_generations or result.generations, llm_output=result.llm_output)
|