mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-04-28 12:48:40 +00:00
chore(adapter): Adapt MindIE engine model and improve testing and fixes (#2523)
* feat(models): 适配 MindIE引擎的模型 * test: add unit tests for MindIEChatModel adapter and fix PR review comments * chore: update uv.lock with pytest-asyncio * build: add pytest-asyncio to test dependencies * fix: address PR review comments (lazy import, cache clients, safe newline escape, strict xml regex) * fix(mindie): preserve string args without JSON quotes in XML tool call serialization * fix(mindie): preserve string args without JSON quotes in XML tool call serialization * test_mindie_provider:format * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix(mindie): prevent nested tool_call params from leaking into outer args * fixed by escaping XML entities in _fix_messages and unescaping during parse, with regression tests added. --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
6bd88fe14c
commit
395c14357b
@ -1,4 +1,5 @@
|
||||
import ast
|
||||
import html
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
@ -36,8 +37,8 @@ def _fix_messages(messages: list) -> list:
|
||||
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", []):
|
||||
xml_parts = []
|
||||
for tool in msg.tool_calls:
|
||||
args_xml = " ".join(f"<parameter={k}>{json.dumps(v, ensure_ascii=False)}</parameter>" for k, v in tool.get("args", {}).items())
|
||||
xml_parts.append(f"<tool_call> <function={tool['name']}> {args_xml} </function> </tool_call>")
|
||||
args_xml = " ".join(f"<parameter={html.escape(str(k), quote=False)}>{html.escape(v if isinstance(v, str) else json.dumps(v, ensure_ascii=False), quote=False)}</parameter>" for k, v in tool.get("args", {}).items())
|
||||
xml_parts.append(f"<tool_call> <function={html.escape(str(tool['name']), quote=False)}> {args_xml} </function> </tool_call>")
|
||||
full_text = f"{text}\n" + "\n".join(xml_parts) if text else "\n".join(xml_parts)
|
||||
fixed.append(AIMessage(content=full_text.strip() or " "))
|
||||
continue
|
||||
@ -80,13 +81,24 @@ def _parse_xml_tool_call_to_dict(content: str) -> tuple[str, list[dict]]:
|
||||
func_match = re.search(r"<function=([^>]+)>", inner_content)
|
||||
if not func_match:
|
||||
continue
|
||||
function_name = func_match.group(1).strip()
|
||||
function_name = html.unescape(func_match.group(1).strip())
|
||||
|
||||
# Ignore nested tool blocks when extracting parameters for this call.
|
||||
# Nested `<tool_call>` sections represent separate invocations and
|
||||
# their `<parameter>` tags must not leak into the current call args.
|
||||
param_source_parts: list[str] = []
|
||||
nested_cursor = 0
|
||||
for nested_start, nested_end, _ in _iter_tool_call_blocks(inner_content):
|
||||
param_source_parts.append(inner_content[nested_cursor:nested_start])
|
||||
nested_cursor = nested_end
|
||||
param_source_parts.append(inner_content[nested_cursor:])
|
||||
param_source = "".join(param_source_parts)
|
||||
|
||||
args = {}
|
||||
param_pattern = re.compile(r"<parameter=([^>]+)>(.*?)</parameter>", re.DOTALL)
|
||||
for param_match in param_pattern.finditer(inner_content):
|
||||
key = param_match.group(1).strip()
|
||||
raw_value = param_match.group(2).strip()
|
||||
for param_match in param_pattern.finditer(param_source):
|
||||
key = html.unescape(param_match.group(1).strip())
|
||||
raw_value = html.unescape(param_match.group(2).strip())
|
||||
|
||||
# Attempt to deserialize string values into native Python types
|
||||
# to satisfy downstream Pydantic validation.
|
||||
|
||||
@ -91,7 +91,7 @@ class TestFixMessages:
|
||||
assert isinstance(out, AIMessage)
|
||||
assert "<tool_call>" in out.content
|
||||
assert "<function=get_weather>" in out.content
|
||||
assert '<parameter=city>"London"</parameter>' in out.content
|
||||
assert "<parameter=city>London</parameter>" in out.content
|
||||
assert not getattr(out, "tool_calls", [])
|
||||
|
||||
def test_ai_message_text_preserved_before_xml(self):
|
||||
@ -116,6 +116,22 @@ class TestFixMessages:
|
||||
assert "<function=tool_a>" in content
|
||||
assert "<function=tool_b>" in content
|
||||
|
||||
def test_ai_message_tool_args_are_xml_escaped(self):
    """Tool names, parameter keys and string values are entity-escaped on output.

    ``_fix_messages`` runs the tool name, every argument key and every
    argument value through ``html.escape(..., quote=False)`` before embedding
    them in the ``<tool_call>`` text, so ``<``, ``&`` and ``>`` must appear as
    ``&lt;``, ``&amp;`` and ``&gt;`` in the resulting message content.
    """
    msg = AIMessage(
        content="",
        tool_calls=[
            {
                "name": "fn<&>",
                "args": {"k<&>": "v<&>"},
                "id": "id1",
            }
        ],
    )
    result = _fix_messages([msg])
    content = result[0].content
    # The expected text is the *escaped* form; looking for the raw "<&>"
    # characters would never match the escaped serializer output.
    assert "<function=fn&lt;&amp;&gt;>" in content
    assert "<parameter=k&lt;&amp;&gt;>v&lt;&amp;&gt;</parameter>" in content
|
||||
|
||||
# ── ToolMessage → HumanMessage ────────────────────────────────────────────
|
||||
|
||||
def test_tool_message_becomes_human_message(self):
|
||||
@ -185,6 +201,15 @@ class TestParseXmlToolCalls:
|
||||
assert calls[0]["name"] == "a"
|
||||
assert calls[1]["name"] == "b"
|
||||
|
||||
def test_nested_tool_call_blocks_do_not_break_parsing(self):
    """A nested ``<tool_call>`` must not leak its parameters into the outer call.

    The input wraps an ``inner`` call (parameter ``x``) inside an ``outer``
    call (parameter ``q``); only the outer call with its own parameter is
    expected back, and no plain text should remain.
    """
    xml = (
        "<tool_call><function=outer><parameter=q>1</parameter>"
        "<tool_call><function=inner><parameter=x>2</parameter></function></tool_call>"
        "</function></tool_call>"
    )
    remaining_text, parsed = _parse_xml_tool_call_to_dict(xml)
    assert remaining_text == ""
    assert len(parsed) == 1
    outer_call = parsed[0]
    assert outer_call["name"] == "outer"
    assert outer_call["args"] == {"q": 1}
    assert "x" not in outer_call["args"]
|
||||
|
||||
def test_text_before_tool_call_preserved(self):
|
||||
content = "Here is the answer.\n<tool_call><function=f><parameter=k>v</parameter></function></tool_call>"
|
||||
clean, calls = _parse_xml_tool_call_to_dict(content)
|
||||
@ -226,6 +251,12 @@ class TestParseXmlToolCalls:
|
||||
_, c2 = _parse_xml_tool_call_to_dict(block)
|
||||
assert c1[0]["id"] != c2[0]["id"]
|
||||
|
||||
def test_escaped_entities_are_unescaped(self):
    """Entity-escaped function names, keys and values are decoded on parse.

    The input carries ``&lt;&amp;&gt;`` in the function name, the parameter
    key and the parameter value. The parser captures names with ``[^>]+`` and
    then applies ``html.unescape``, so the entities must be present in the
    raw text for the decoded ``<&>`` to come out intact.
    """
    content = "<tool_call><function=fn&lt;&amp;&gt;><parameter=k&lt;&amp;&gt;>v&lt;&amp;&gt;</parameter></function></tool_call>"
    _, calls = _parse_xml_tool_call_to_dict(content)
    assert calls[0]["name"] == "fn<&>"
    assert calls[0]["args"]["k<&>"] == "v<&>"
|
||||
|
||||
|
||||
# ═════════════════════════════════════════════════════════════════════════════
|
||||
# 3. MindIEChatModel._patch_result_with_tools
|
||||
@ -244,6 +275,12 @@ class TestPatchResult:
|
||||
patched = model._patch_result_with_tools(result)
|
||||
assert patched.generations[0].message.content == "line1\nline2"
|
||||
|
||||
def test_escaped_newlines_inside_code_fence_preserved(self):
    """Escaped newlines are fixed outside a code fence but kept inside it.

    The literal backslash-n sequences before and after the fenced block are
    expected to become real newlines, while the escaped sequence inside the
    fenced JSON must come back unchanged.
    """
    raw_content = 'text\\n```json\n{"k":"a\\\\nb"}\n```\\nend'
    expected_content = 'text\n```json\n{"k":"a\\\\nb"}\n```\nend'
    patched = self._model()._patch_result_with_tools(_make_chat_result(raw_content))
    message = patched.generations[0].message
    assert message.content == expected_content
|
||||
|
||||
def test_xml_tool_calls_extracted(self):
|
||||
model = self._model()
|
||||
content = "<tool_call><function=calc><parameter=expr>1+1</parameter></function></tool_call>"
|
||||
@ -281,6 +318,50 @@ class TestPatchResult:
|
||||
assert patched is not None
|
||||
|
||||
|
||||
class TestMindIEInit:
    """Constructor tests for the timeout keyword handling of ``MindIEChatModel``."""

    @staticmethod
    def _init_kwargs_seen_by_parent(**model_kwargs):
        """Construct ``MindIEChatModel`` and return the kwargs its parent init received.

        ``ChatOpenAI.__init__`` is swapped for a recorder while the model is
        being built, so no real client setup runs; the recorded keyword
        arguments are returned as a dict.
        """
        seen = {}

        def recording_init(self, **kwargs):
            seen.update(kwargs)

        with patch("deerflow.models.mindie_provider.ChatOpenAI.__init__", new=recording_init):
            MindIEChatModel(**model_kwargs)
        return seen

    def test_timeout_kwargs_are_normalized(self):
        """The four per-phase timeout kwargs end up in a single ``timeout`` object."""
        seen = self._init_kwargs_seen_by_parent(
            model="mindie-test",
            api_key="test-key",
            connect_timeout=1.0,
            read_timeout=2.0,
            write_timeout=3.0,
            pool_timeout=4.0,
        )

        timeout = seen.get("timeout")
        assert timeout is not None
        observed = (timeout.connect, timeout.read, timeout.write, timeout.pool)
        assert observed == (1.0, 2.0, 3.0, 4.0)

    def test_explicit_timeout_takes_precedence(self):
        """An explicit ``timeout`` wins over the per-phase timeout kwargs."""
        seen = self._init_kwargs_seen_by_parent(
            model="mindie-test",
            api_key="test-key",
            timeout=9.0,
            connect_timeout=1.0,
            read_timeout=2.0,
            write_timeout=3.0,
            pool_timeout=4.0,
        )

        assert seen.get("timeout") == 9.0
|
||||
|
||||
|
||||
# ═════════════════════════════════════════════════════════════════════════════
|
||||
# 4. MindIEChatModel._generate (sync)
|
||||
# ═════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user