diff --git a/backend/packages/harness/deerflow/models/assistant_payload_replay.py b/backend/packages/harness/deerflow/models/assistant_payload_replay.py
new file mode 100644
index 000000000..cffba5e99
--- /dev/null
+++ b/backend/packages/harness/deerflow/models/assistant_payload_replay.py
@@ -0,0 +1,161 @@
+"""Helpers for replaying provider-specific assistant message fields.
+
+Several provider adapters need to preserve fields that LangChain stores on the
+original ``AIMessage`` but drops when serializing request payloads. This module
+keeps the assistant-message matching logic shared while letting each provider
+decide which fields to restore.
+"""
+
+from __future__ import annotations
+
+import json
+from collections.abc import Callable, Sequence
+from typing import Any
+
+from langchain_core.messages import AIMessage, BaseMessage
+
+AssistantPayloadRestorer = Callable[[dict[str, Any], AIMessage], None]
+
+
+def restore_assistant_payloads(
+    payload_messages: Sequence[dict[str, Any]],
+    original_messages: Sequence[BaseMessage],
+    restore: AssistantPayloadRestorer,
+) -> None:
+    """Restore provider-specific fields onto serialized assistant payloads.
+
+    When both sequences have the same length they are paired by position and
+    ``restore`` runs for each pair whose payload role is ``"assistant"`` and
+    whose original is an ``AIMessage``. Otherwise every assistant payload is
+    paired with an ``AIMessage`` by content/tool-call signature, falling back
+    to assistant order; payloads that cannot be paired are left untouched.
+
+    Args:
+        payload_messages: Serialized request messages, mutated in place by
+            ``restore``.
+        original_messages: Messages the payload was serialized from.
+        restore: Callback copying fields from an ``AIMessage`` onto its
+            payload dict.
+    """
+    if len(payload_messages) == len(original_messages):
+        for payload_msg, orig_msg in zip(payload_messages, original_messages):
+            if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
+                restore(payload_msg, orig_msg)
+        return
+
+    ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
+    assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
+    used_ai_indexes: set[int] = set()
+
+    for ordinal, payload_msg in enumerate(assistant_payloads):
+        ai_msg = _match_ai_message(payload_msg, ai_messages, used_ai_indexes, ordinal)
+        if ai_msg is not None:
+            restore(payload_msg, ai_msg)
+
+
+def restore_additional_kwargs_field(payload_msg: dict[str, Any], orig_msg: AIMessage, field_name: str) -> None:
+    """Copy a provider-specific ``additional_kwargs`` field onto a payload message.
+
+    Missing and ``None`` values are skipped; other falsy values such as ``""``
+    are still copied.
+    """
+    value = orig_msg.additional_kwargs.get(field_name)
+    if value is not None:
+        payload_msg[field_name] = value
+
+
+def restore_reasoning_content(payload_msg: dict[str, Any], orig_msg: AIMessage) -> None:
+    """Copy provider reasoning content onto a serialized assistant payload."""
+    restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
+
+
+def _match_ai_message(
+    payload_msg: dict[str, Any],
+    ai_messages: Sequence[AIMessage],
+    used_ai_indexes: set[int],
+    fallback_ordinal: int,
+) -> AIMessage | None:
+    """Pick the ``AIMessage`` for ``payload_msg`` and mark its index as used.
+
+    A signature match wins only when exactly one unused message has it;
+    otherwise the next unused index at or after ``fallback_ordinal`` is taken.
+    Returns ``None`` when no unused candidate remains.
+    """
+    payload_key = _assistant_signature(payload_msg)
+    if payload_key is not None:
+        matches = [index for index, ai_msg in enumerate(ai_messages) if index not in used_ai_indexes and _ai_signature(ai_msg) == payload_key]
+        if len(matches) == 1:
+            used_ai_indexes.add(matches[0])
+            return ai_messages[matches[0]]
+
+    fallback_index = _next_unused_index_at_or_after(len(ai_messages), used_ai_indexes, fallback_ordinal)
+    if fallback_index is not None:
+        used_ai_indexes.add(fallback_index)
+        return ai_messages[fallback_index]
+
+    return None
+
+
+def _next_unused_index_at_or_after(count: int, used_ai_indexes: set[int], start: int) -> int | None:
+    """Return the next unused AI index at or after ``start``.
+
+    Scanning forward from the payload's ordinal preserves the positional bias of
+    the previous behaviour while still recovering when serialization drops or
+    reorders messages so the exact ordinal index is already taken. It does not
+    wrap to earlier indexes because those messages may be represented by payload
+    entries that were already dropped.
+    """
+    if count == 0 or start >= count:
+        return None
+    for index in range(start, count):
+        if index not in used_ai_indexes:
+            return index
+    return None
+
+
+def _assistant_signature(payload_msg: dict[str, Any]) -> tuple[str, str] | None:
+    """Build the matching signature of a serialized assistant payload."""
+    return _signature(
+        payload_msg.get("content"),
+        _tool_call_ids(payload_msg.get("tool_calls") or []),
+    )
+
+
+def _ai_signature(message: AIMessage) -> tuple[str, str] | None:
+    """Build the matching signature of an original ``AIMessage``."""
+    tool_calls = message.tool_calls or message.additional_kwargs.get("tool_calls") or []
+    return _signature(message.content, _tool_call_ids(tool_calls))
+
+
+def _signature(content: Any, tool_call_ids: tuple[str, ...]) -> tuple[str, str] | None:
+    """Return a comparable ``(content, tool-call ids)`` key, or ``None``.
+
+    ``None`` and ``""`` both mean "no content" and are normalized to ``None`` so
+    that a payload carrying ``None`` for empty content still matches an
+    ``AIMessage`` holding ``""`` when tool-call ids identify it.
+    """
+    if content in (None, ""):
+        if not tool_call_ids:
+            return None
+        content = None
+    return (_stable_repr(content), "|".join(tool_call_ids))
+
+
+def _stable_repr(value: Any) -> str:
+    """Serialize ``value`` deterministically, falling back to ``repr``."""
+    try:
+        return json.dumps(value, sort_keys=True, ensure_ascii=False)
+    except (TypeError, ValueError):
+        # ValueError covers circular references, TypeError unserializable values.
+        return repr(value)
+
+
+def _tool_call_ids(tool_calls: Sequence[Any]) -> tuple[str, ...]:
+    """Collect the non-empty string ``id`` of every dict-shaped tool call."""
+    ids: list[str] = []
+    for tool_call in tool_calls:
+        if isinstance(tool_call, dict):
+            call_id = tool_call.get("id")
+            if isinstance(call_id, str) and call_id:
+                ids.append(call_id)
+    return tuple(ids)
diff --git a/backend/packages/harness/deerflow/models/patched_deepseek.py b/backend/packages/harness/deerflow/models/patched_deepseek.py
index b25e60911..341ff9b6b 100644
--- a/backend/packages/harness/deerflow/models/patched_deepseek.py
+++ b/backend/packages/harness/deerflow/models/patched_deepseek.py
@@ -10,9 +10,10 @@ on all assistant messages when thinking mode is enabled.
 from typing import Any
 
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import AIMessage
 from langchain_deepseek import ChatDeepSeek
 
+from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
+
 
 class PatchedChatDeepSeek(ChatDeepSeek):
     """ChatDeepSeek with proper reasoning_content preservation.
@@ -49,25 +50,10 @@ class PatchedChatDeepSeek(ChatDeepSeek):
         # Call parent to get the base payload
         payload = super()._get_request_payload(input_, stop=stop, **kwargs)
 
-        # Match payload messages with original messages to restore reasoning_content
-        payload_messages = payload.get("messages", [])
-
-        # The payload messages and original messages should be in the same order
-        # Iterate through both and match by position
-        if len(payload_messages) == len(original_messages):
-            for payload_msg, orig_msg in zip(payload_messages, original_messages):
-                if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
-                    reasoning_content = orig_msg.additional_kwargs.get("reasoning_content")
-                    if reasoning_content is not None:
-                        payload_msg["reasoning_content"] = reasoning_content
-        else:
-            # Fallback: match by counting assistant messages
-            ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
-            assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
-
-            for (idx, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
-                reasoning_content = ai_msg.additional_kwargs.get("reasoning_content")
-                if reasoning_content is not None:
-                    payload_messages[idx]["reasoning_content"] = reasoning_content
+        restore_assistant_payloads(
+            payload.get("messages", []),
+            original_messages,
+            restore_reasoning_content,
+        )
 
         return payload
diff --git a/backend/packages/harness/deerflow/models/patched_mimo.py b/backend/packages/harness/deerflow/models/patched_mimo.py
index 7589af78f..3851748a4 100644
--- a/backend/packages/harness/deerflow/models/patched_mimo.py
+++ b/backend/packages/harness/deerflow/models/patched_mimo.py
@@ -17,6 +17,8 @@ from langchain_core.messages import AIMessage, AIMessageChunk
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_openai import ChatOpenAI
 
+from deerflow.models.assistant_payload_replay import restore_assistant_payloads, restore_reasoning_content
+
 _MISSING = object()
 
 
@@ -45,12 +47,6 @@ def _with_reasoning_content(message: AIMessage | AIMessageChunk, reasoning: str)
     return message.model_copy(update={"additional_kwargs": additional_kwargs})
 
 
-def _restore_reasoning_content(payload_msg: dict, orig_msg: AIMessage) -> None:
-    reasoning = orig_msg.additional_kwargs.get("reasoning_content")
-    if reasoning is not None:
-        payload_msg["reasoning_content"] = reasoning
-
-
 def _get_typed_choice_message(response: Any, index: int) -> Any:
     choices = getattr(response, "choices", None)
     if choices is None:
@@ -81,17 +77,11 @@ class PatchedChatMiMo(ChatOpenAI):
     ) -> dict:
         original_messages = self._convert_input(input_).to_messages()
         payload = super()._get_request_payload(input_, stop=stop, **kwargs)
-        payload_messages = payload.get("messages", [])
-
-        if len(payload_messages) == len(original_messages):
-            for payload_msg, orig_msg in zip(payload_messages, original_messages):
-                if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
-                    _restore_reasoning_content(payload_msg, orig_msg)
-        else:
-            ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
-            assistant_payloads = [m for m in payload_messages if m.get("role") == "assistant"]
-            for payload_msg, ai_msg in zip(assistant_payloads, ai_messages):
-                _restore_reasoning_content(payload_msg, ai_msg)
+        restore_assistant_payloads(
+            payload.get("messages", []),
+            original_messages,
+            restore_reasoning_content,
+        )
 
         return payload
 
diff --git a/backend/packages/harness/deerflow/models/patched_openai.py b/backend/packages/harness/deerflow/models/patched_openai.py
index 9a7801f48..cf98141de 100644
--- a/backend/packages/harness/deerflow/models/patched_openai.py
+++ b/backend/packages/harness/deerflow/models/patched_openai.py
@@ -27,6 +27,8 @@ from langchain_core.language_models import LanguageModelInput
 from langchain_core.messages import AIMessage
 from langchain_openai import ChatOpenAI
 
+from deerflow.models.assistant_payload_replay import restore_assistant_payloads
+
 
 class PatchedChatOpenAI(ChatOpenAI):
     """ChatOpenAI with ``thought_signature`` preservation for Gemini thinking via OpenAI gateway.
@@ -75,18 +77,7 @@ class PatchedChatOpenAI(ChatOpenAI):
         # Obtain the base payload from the parent implementation.
         payload = super()._get_request_payload(input_, stop=stop, **kwargs)
 
-        payload_messages = payload.get("messages", [])
-
-        if len(payload_messages) == len(original_messages):
-            for payload_msg, orig_msg in zip(payload_messages, original_messages):
-                if payload_msg.get("role") == "assistant" and isinstance(orig_msg, AIMessage):
-                    _restore_tool_call_signatures(payload_msg, orig_msg)
-        else:
-            # Fallback: match assistant-role entries positionally against AIMessages.
-            ai_messages = [m for m in original_messages if isinstance(m, AIMessage)]
-            assistant_payloads = [(i, m) for i, m in enumerate(payload_messages) if m.get("role") == "assistant"]
-            for (_, payload_msg), ai_msg in zip(assistant_payloads, ai_messages):
-                _restore_tool_call_signatures(payload_msg, ai_msg)
+        restore_assistant_payloads(payload.get("messages", []), original_messages, _restore_tool_call_signatures)
 
         return payload
 
diff --git a/backend/tests/test_assistant_payload_replay.py b/backend/tests/test_assistant_payload_replay.py
new file mode 100644
index 000000000..433ce9f85
--- /dev/null
+++ b/backend/tests/test_assistant_payload_replay.py
@@ -0,0 +1,195 @@
+"""Tests for shared assistant payload replay helpers."""
+
+from __future__ import annotations
+
+from langchain_core.messages import AIMessage, HumanMessage
+
+from deerflow.models.assistant_payload_replay import (
+    restore_additional_kwargs_field,
+    restore_assistant_payloads,
+    restore_reasoning_content,
+)
+
+
+def _restore_reasoning(payload_msg: dict, orig_msg: AIMessage) -> None:
+    restore_additional_kwargs_field(payload_msg, orig_msg, "reasoning_content")
+
+
+def test_restore_additional_kwargs_field_copies_present_values_only():
+    payload_message = {"role": "assistant"}
+    orig_message = AIMessage(
+        content="answer",
+        additional_kwargs={
+            "reasoning_content": "",
+            "ignored_none": None,
+        },
+    )
+
+    restore_additional_kwargs_field(payload_message, orig_message, "reasoning_content")
+    restore_additional_kwargs_field(payload_message, orig_message, "ignored_none")
+    restore_additional_kwargs_field(payload_message, orig_message, "missing")
+
+    assert payload_message == {"role": "assistant", "reasoning_content": ""}
+
+
+def test_restore_reasoning_content_copies_reasoning_content():
+    payload_message = {"role": "assistant"}
+    orig_message = AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"})
+
+    restore_reasoning_content(payload_message, orig_message)
+
+    assert payload_message["reasoning_content"] == "thought"
+
+
+def test_restore_assistant_payloads_matches_by_position_when_lengths_match():
+    original_messages = [
+        HumanMessage(content="question"),
+        AIMessage(content="answer", additional_kwargs={"reasoning_content": "thought"}),
+    ]
+    payload_messages = [
+        {"role": "user", "content": "question"},
+        {"role": "assistant", "content": "answer"},
+    ]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[1]["reasoning_content"] == "thought"
+
+
+def test_restore_assistant_payloads_fallback_matches_unique_content_signature():
+    original_messages = [
+        AIMessage(content="first", additional_kwargs={"reasoning_content": "first-thought"}),
+        AIMessage(content="second", additional_kwargs={"reasoning_content": "second-thought"}),
+    ]
+    payload_messages = [{"role": "assistant", "content": "second"}]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[0]["reasoning_content"] == "second-thought"
+
+
+def test_restore_assistant_payloads_fallback_matches_unique_tool_call_signature():
+    original_messages = [
+        AIMessage(
+            content="",
+            additional_kwargs={"reasoning_content": "first-thought"},
+            tool_calls=[{"id": "call_first", "name": "tool", "args": {}}],
+        ),
+        AIMessage(
+            content="",
+            additional_kwargs={"reasoning_content": "second-thought"},
+            tool_calls=[{"id": "call_second", "name": "tool", "args": {}}],
+        ),
+    ]
+    payload_messages = [
+        {
+            "role": "assistant",
+            "content": "",
+            "tool_calls": [{"id": "call_second", "type": "function", "function": {"name": "tool", "arguments": "{}"}}],
+        }
+    ]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[0]["reasoning_content"] == "second-thought"
+
+
+def test_restore_assistant_payloads_fallback_matches_tool_calls_when_empty_content_is_null():
+    # The payload carries ``content: None`` while the original AIMessage holds
+    # ``""``; both mean "empty", so the tool-call id must still identify the
+    # message instead of falling back to assistant order.
+    original_messages = [
+        AIMessage(
+            content="",
+            additional_kwargs={"reasoning_content": "first-thought"},
+            tool_calls=[{"id": "call_first", "name": "tool", "args": {}}],
+        ),
+        AIMessage(
+            content="",
+            additional_kwargs={"reasoning_content": "second-thought"},
+            tool_calls=[{"id": "call_second", "name": "tool", "args": {}}],
+        ),
+    ]
+    payload_messages = [
+        {
+            "role": "assistant",
+            "content": None,
+            "tool_calls": [{"id": "call_second", "type": "function", "function": {"name": "tool", "arguments": "{}"}}],
+        }
+    ]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[0]["reasoning_content"] == "second-thought"
+
+
+def test_restore_assistant_payloads_fallback_matches_structured_content_signature():
+    original_messages = [
+        AIMessage(
+            content=[{"type": "text", "text": "first"}],
+            additional_kwargs={"reasoning_content": "first-thought"},
+        ),
+        AIMessage(
+            content=[{"type": "text", "text": "second"}],
+            additional_kwargs={"reasoning_content": "second-thought"},
+        ),
+    ]
+    payload_messages = [{"role": "assistant", "content": [{"text": "second", "type": "text"}]}]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[0]["reasoning_content"] == "second-thought"
+
+
+def test_restore_assistant_payloads_fallback_uses_order_when_signature_is_ambiguous():
+    original_messages = [
+        AIMessage(content="", additional_kwargs={"reasoning_content": "first-thought"}),
+        AIMessage(content="", additional_kwargs={"reasoning_content": "second-thought"}),
+    ]
+    payload_messages = [{"role": "assistant", "content": ""}]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[0]["reasoning_content"] == "first-thought"
+
+
+def test_restore_assistant_payloads_fallback_uses_next_unused_when_ordinal_taken():
+    # Serialization dropped a leading empty assistant message, so payload ordinals
+    # no longer line up with the original AIMessage indices. The first payload
+    # uniquely matches a non-ordinal index by signature, which leaves the later
+    # ambiguous payload's exact ordinal index already used. It must still fall
+    # back to the remaining unused AIMessage (scanning forward from the ordinal)
+    # instead of silently dropping the field.
+    original_messages = [
+        AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-thought"}),
+        AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
+        AIMessage(content="", additional_kwargs={"reasoning_content": "trailing-thought"}),
+    ]
+    payload_messages = [
+        {"role": "assistant", "content": "unique"},
+        {"role": "assistant", "content": ""},
+    ]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[0]["reasoning_content"] == "unique-thought"
+    # Forward scan from the taken ordinal picks the trailing message, not the
+    # dropped leading one (which a naive min-unused scan would wrongly select).
+    assert payload_messages[1]["reasoning_content"] == "trailing-thought"
+
+
+def test_restore_assistant_payloads_does_not_wrap_to_earlier_unused_message():
+    original_messages = [
+        HumanMessage(content="leading user"),
+        AIMessage(content="", additional_kwargs={"reasoning_content": "dropped-leading-thought"}),
+        AIMessage(content="unique", additional_kwargs={"reasoning_content": "unique-thought"}),
+    ]
+    payload_messages = [
+        {"role": "assistant", "content": "unique"},
+        {"role": "assistant", "content": ""},
+    ]
+
+    restore_assistant_payloads(payload_messages, original_messages, _restore_reasoning)
+
+    assert payload_messages[0]["reasoning_content"] == "unique-thought"
+    assert "reasoning_content" not in payload_messages[1]