From 2bb1a2dfa28fb79a308b5f980fabd44693bcd0f7 Mon Sep 17 00:00:00 2001 From: pyp0327 <108285878+pyp0327@users.noreply.github.com> Date: Sat, 25 Apr 2026 08:59:03 +0800 Subject: [PATCH] feat(models): Provider for MindIE model engine (#2483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(models): 适配 MindIE引擎的模型 * test: add unit tests for MindIEChatModel adapter and fix PR review comments * chore: update uv.lock with pytest-asyncio * build: add pytest-asyncio to test dependencies * fix: address PR review comments (lazy import, cache clients, safe newline escape, strict xml regex) --------- Co-authored-by: Willem Jiang --- .../harness/deerflow/models/factory.py | 6 + .../deerflow/models/mindie_provider.py | 237 +++++++++++ backend/pyproject.toml | 7 +- backend/tests/test_mindie_provider.py | 397 ++++++++++++++++++ backend/uv.lock | 15 + config.example.yaml | 21 + 6 files changed, 682 insertions(+), 1 deletion(-) create mode 100644 backend/packages/harness/deerflow/models/mindie_provider.py create mode 100644 backend/tests/test_mindie_provider.py diff --git a/backend/packages/harness/deerflow/models/factory.py b/backend/packages/harness/deerflow/models/factory.py index bd2828e94..aec9b291a 100644 --- a/backend/packages/harness/deerflow/models/factory.py +++ b/backend/packages/harness/deerflow/models/factory.py @@ -131,6 +131,12 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, * elif "reasoning_effort" not in model_settings_from_config: model_settings_from_config["reasoning_effort"] = "medium" + # For MindIE models: enforce conservative retry defaults. + # Timeout normalization is handled inside MindIEChatModel itself. + if getattr(model_class, "__name__", "") == "MindIEChatModel": + # Enforce max_retries constraint to prevent cascading timeouts. + model_settings_from_config["max_retries"] = model_settings_from_config.get("max_retries", 1) + model_instance = model_class(**{**model_settings_from_config, **kwargs}) callbacks = build_tracing_callbacks() diff --git a/backend/packages/harness/deerflow/models/mindie_provider.py b/backend/packages/harness/deerflow/models/mindie_provider.py new file mode 100644 index 000000000..5f0d12e83 --- /dev/null +++ b/backend/packages/harness/deerflow/models/mindie_provider.py @@ -0,0 +1,237 @@ +import ast +import json +import re +import uuid +from collections.abc import Iterator + +import httpx +from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage, ToolMessage +from langchain_core.outputs import ChatGenerationChunk, ChatResult +from langchain_openai import ChatOpenAI + + +def _fix_messages(messages: list) -> list: + """Sanitize incoming messages for MindIE compatibility. + + MindIE's chat template may fail to parse LangChain's native tool_calls + or ToolMessage roles, resulting in 0-token generation errors. This function + flattens multi-modal list contents into strings and converts tool-related + messages into raw text with XML tags expected by the underlying model. + """ + fixed = [] + for msg in messages: + # Flatten content if it's a list of blocks + if isinstance(msg.content, list): + parts = [] + for block in msg.content: + if isinstance(block, str): + parts.append(block) + elif isinstance(block, dict) and block.get("type") == "text": + parts.append(block.get("text", "")) + text = "".join(parts) + else: + text = msg.content or "" + + # Convert AIMessage with tool_calls to raw XML text format + if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", []): + xml_parts = [] + for tool in msg.tool_calls: + args_xml = " ".join(f"{json.dumps(v, ensure_ascii=False)}" for k, v in tool.get("args", {}).items()) + xml_parts.append(f" {args_xml} ") + full_text = f"{text}\n" + "\n".join(xml_parts) if text else "\n".join(xml_parts) + fixed.append(AIMessage(content=full_text.strip() or " ")) + continue + + # Wrap tool execution results in XML tags and convert to HumanMessage + if isinstance(msg, ToolMessage): + tool_result_text = f"\n{text}\n" + fixed.append(HumanMessage(content=tool_result_text)) + continue + + # Fallback to prevent completely empty message content + if not text.strip(): + text = " " + + fixed.append(msg.model_copy(update={"content": text})) + + return fixed + + +def _parse_xml_tool_call_to_dict(content: str) -> tuple[str, list[dict]]: + """Parse XML-style tool calls from model output into LangChain dicts. + + Args: + content: The raw text output from the model. + + Returns: + A tuple containing the cleaned text (with XML blocks removed) and + a list of tool call dictionaries formatted for LangChain. + """ + if not isinstance(content, str) or "" not in content: + return content, [] + + tool_calls = [] + clean_parts: list[str] = [] + cursor = 0 + for start, end, inner_content in _iter_tool_call_blocks(content): + clean_parts.append(content[cursor:start]) + cursor = end + + func_match = re.search(r"]+)>", inner_content) + if not func_match: + continue + function_name = func_match.group(1).strip() + + args = {} + param_pattern = re.compile(r"]+)>(.*?)", re.DOTALL) + for param_match in param_pattern.finditer(inner_content): + key = param_match.group(1).strip() + raw_value = param_match.group(2).strip() + + # Attempt to deserialize string values into native Python types + # to satisfy downstream Pydantic validation. + parsed_value = raw_value + if raw_value.startswith(("[", "{")) or raw_value in ("true", "false", "null") or raw_value.isdigit(): + try: + parsed_value = json.loads(raw_value) + except json.JSONDecodeError: + try: + parsed_value = ast.literal_eval(raw_value) + except (ValueError, SyntaxError): + pass + + args[key] = parsed_value + + tool_calls.append({"name": function_name, "args": args, "id": f"call_{uuid.uuid4().hex[:10]}"}) + clean_parts.append(content[cursor:]) + + return "".join(clean_parts).strip(), tool_calls + + +def _iter_tool_call_blocks(content: str) -> Iterator[tuple[int, int, str]]: + """Iterate `...` blocks and tolerate nesting.""" + token_pattern = re.compile(r"") + depth = 0 + block_start = -1 + + for match in token_pattern.finditer(content): + token = match.group(0) + if token == "": + if depth == 0: + block_start = match.start() + depth += 1 + continue + + if depth == 0: + continue + + depth -= 1 + if depth == 0 and block_start != -1: + block_end = match.end() + inner_start = block_start + len("") + inner_end = match.start() + yield block_start, block_end, content[inner_start:inner_end] + block_start = -1 + + +def _decode_escaped_newlines_outside_fences(content: str) -> str: + """Decode literal `\\n` outside fenced code blocks.""" + if "\\n" not in content: + return content + + parts = re.split(r"(```[\s\S]*?```)", content) + for idx, part in enumerate(parts): + if part.startswith("```"): + continue + parts[idx] = part.replace("\\n", "\n") + return "".join(parts) + + +class MindIEChatModel(ChatOpenAI): + """Chat model adapter for MindIE engine. + + Addresses compatibility issues including: + - Flattening multimodal list contents to strings. + - Intercepting and parsing hardcoded XML tool calls into LangChain standard. + - Handling stream=True dropping choices when tools are present by falling back + to non-streaming generation and yielding simulated chunks. + - Fixing over-escaped newline characters from gateway responses. + """ + + def __init__(self, **kwargs): + """Normalize timeout kwargs without creating long-lived clients.""" + connect_timeout = kwargs.pop("connect_timeout", 30.0) + read_timeout = kwargs.pop("read_timeout", 900.0) + write_timeout = kwargs.pop("write_timeout", 60.0) + pool_timeout = kwargs.pop("pool_timeout", 30.0) + + kwargs.setdefault( + "timeout", + httpx.Timeout( + connect=connect_timeout, + read=read_timeout, + write=write_timeout, + pool=pool_timeout, + ), + ) + super().__init__(**kwargs) + + def _patch_result_with_tools(self, result: ChatResult) -> ChatResult: + """Apply post-generation fixes to the model result.""" + for gen in result.generations: + msg = gen.message + + if isinstance(msg.content, str): + # Keep escaped newlines inside fenced code blocks untouched. + msg.content = _decode_escaped_newlines_outside_fences(msg.content) + + if "" in msg.content: + clean_content, extracted_tools = _parse_xml_tool_call_to_dict(msg.content) + + if extracted_tools: + msg.content = clean_content + if getattr(msg, "tool_calls", None) is None: + msg.tool_calls = [] + msg.tool_calls.extend(extracted_tools) + return result + + def _generate(self, messages, stop=None, run_manager=None, **kwargs): + result = super()._generate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs) + return self._patch_result_with_tools(result) + + async def _agenerate(self, messages, stop=None, run_manager=None, **kwargs): + result = await super()._agenerate(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs) + return self._patch_result_with_tools(result) + + async def _astream(self, messages, stop=None, run_manager=None, **kwargs): + # Route standard queries to native streaming for lower TTFB + if not kwargs.get("tools"): + async for chunk in super()._astream(_fix_messages(messages), stop=stop, run_manager=run_manager, **kwargs): + if isinstance(chunk.message.content, str): + chunk.message.content = _decode_escaped_newlines_outside_fences(chunk.message.content) + yield chunk + return + + # Fallback for tool-enabled requests: + # MindIE currently drops choices when stream=True and tools are present. + # We await the full generation and yield chunks to simulate streaming. + result = await self._agenerate(messages, stop=stop, run_manager=run_manager, **kwargs) + + for gen in result.generations: + msg = gen.message + content = msg.content + standard_tool_calls = getattr(msg, "tool_calls", []) + + # Yield text in chunks to allow downstream UI/Markdown parsers to render smoothly + if isinstance(content, str) and content: + chunk_size = 15 + for i in range(0, len(content), chunk_size): + chunk_text = content[i : i + chunk_size] + chunk_msg = AIMessageChunk(content=chunk_text, id=msg.id, response_metadata=msg.response_metadata if i == 0 else {}) + yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info if i == 0 else None) + + if standard_tool_calls: + yield ChatGenerationChunk(message=AIMessageChunk(content="", id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", []))) + else: + chunk_msg = AIMessageChunk(content=content, id=msg.id, tool_calls=standard_tool_calls, invalid_tool_calls=getattr(msg, "invalid_tool_calls", [])) + yield ChatGenerationChunk(message=chunk_msg, generation_info=gen.generation_info) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 220ac23d6..fe280d701 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,7 +20,12 @@ dependencies = [ ] [dependency-groups] -dev = ["prompt-toolkit>=3.0.0", "pytest>=9.0.3", "ruff>=0.14.11"] +dev = [ + "prompt-toolkit>=3.0.0", + "pytest>=9.0.3", + "pytest-asyncio>=1.3.0", + "ruff>=0.14.11", +] [tool.uv.workspace] members = ["packages/harness"] diff --git a/backend/tests/test_mindie_provider.py b/backend/tests/test_mindie_provider.py new file mode 100644 index 000000000..552966c37 --- /dev/null +++ b/backend/tests/test_mindie_provider.py @@ -0,0 +1,397 @@ +""" +Unit tests for MindIEChatModel adapter. +""" + +from unittest.mock import AsyncMock, patch + +import pytest +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.outputs import ChatGeneration, ChatResult + +# ── Import the module under test ────────────────────────────────────────────── +from deerflow.models.mindie_provider import ( + MindIEChatModel, + _fix_messages, + _parse_xml_tool_call_to_dict, +) + +# ═════════════════════════════════════════════════════════════════════════════ +# Helpers +# ═════════════════════════════════════════════════════════════════════════════ + + +def _make_chat_result(content: str, tool_calls=None) -> ChatResult: + msg = AIMessage(content=content) + if tool_calls: + msg.tool_calls = tool_calls + gen = ChatGeneration(message=msg) + return ChatResult(generations=[gen]) + + +# ═════════════════════════════════════════════════════════════════════════════ +# 1. _fix_messages +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestFixMessages: + # ── list content → str ──────────────────────────────────────────────────── + + def test_list_content_extracted_to_str(self): + msg = HumanMessage( + content=[ + {"type": "text", "text": "Hello"}, + {"type": "text", "text": " world"}, + ] + ) + result = _fix_messages([msg]) + assert result[0].content == "Hello world" + + def test_list_content_ignores_non_text_blocks(self): + msg = HumanMessage( + content=[ + {"type": "image_url", "image_url": "http://x.com/img.png"}, + {"type": "text", "text": "caption"}, + ] + ) + result = _fix_messages([msg]) + assert result[0].content == "caption" + + def test_empty_list_content_becomes_space(self): + msg = HumanMessage(content=[]) + result = _fix_messages([msg]) + assert result[0].content == " " + + # ── plain str content ───────────────────────────────────────────────────── + + def test_plain_string_content_preserved(self): + msg = HumanMessage(content="hi there") + result = _fix_messages([msg]) + assert result[0].content == "hi there" + + def test_empty_string_content_becomes_space(self): + msg = HumanMessage(content="") + result = _fix_messages([msg]) + assert result[0].content == " " + + # ── AIMessage with tool_calls → XML ─────────────────────────────────────── + + def test_ai_message_with_tool_calls_serialised_to_xml(self): + msg = AIMessage( + content="Sure", + tool_calls=[ + { + "name": "get_weather", + "args": {"city": "London"}, + "id": "call_abc", + } + ], + ) + result = _fix_messages([msg]) + out = result[0] + assert isinstance(out, AIMessage) + assert "" in out.content + assert "" in out.content + assert '"London"' in out.content + assert not getattr(out, "tool_calls", []) + + def test_ai_message_text_preserved_before_xml(self): + msg = AIMessage( + content="Here you go", + tool_calls=[{"name": "search", "args": {"q": "pytest"}, "id": "x"}], + ) + result = _fix_messages([msg]) + assert result[0].content.startswith("Here you go") + + def test_ai_message_multiple_tool_calls(self): + msg = AIMessage( + content="", + tool_calls=[ + {"name": "tool_a", "args": {"x": 1}, "id": "id1"}, + {"name": "tool_b", "args": {"y": 2}, "id": "id2"}, + ], + ) + result = _fix_messages([msg]) + content = result[0].content + assert content.count("") == 2 + assert "" in content + assert "" in content + + # ── ToolMessage → HumanMessage ──────────────────────────────────────────── + + def test_tool_message_becomes_human_message(self): + msg = ToolMessage(content="42 degrees", tool_call_id="call_abc") + result = _fix_messages([msg]) + out = result[0] + assert isinstance(out, HumanMessage) + assert "" in out.content + assert "42 degrees" in out.content + + def test_tool_message_with_list_content(self): + msg = ToolMessage( + content=[{"type": "text", "text": "result"}], + tool_call_id="call_xyz", + ) + result = _fix_messages([msg]) + assert isinstance(result[0], HumanMessage) + assert "result" in result[0].content + + # ── Mixed message list ──────────────────────────────────────────────────── + + def test_mixed_message_types_ordering_preserved(self): + msgs = [ + HumanMessage(content="q"), + AIMessage(content="a"), + ToolMessage(content="tool out", tool_call_id="c1"), + HumanMessage(content="follow up"), + ] + result = _fix_messages(msgs) + assert len(result) == 4 + assert isinstance(result[2], HumanMessage) + assert result[3].content == "follow up" + + # ── SystemMessage pass-through ──────────────────────────────────────────── + + def test_system_message_passed_through_unchanged(self): + msg = SystemMessage(content="You are helpful.") + result = _fix_messages([msg]) + assert result[0].content == "You are helpful." + + +# ═════════════════════════════════════════════════════════════════════════════ +# 2. _parse_xml_tool_call_to_dict +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestParseXmlToolCalls: + def test_no_tool_call_returns_original(self): + content = "Just a normal reply." + clean, calls = _parse_xml_tool_call_to_dict(content) + assert clean == content + assert calls == [] + + def test_single_tool_call_parsed(self): + content = " pytest " + clean, calls = _parse_xml_tool_call_to_dict(content) + assert clean == "" + assert len(calls) == 1 + assert calls[0]["name"] == "search" + assert calls[0]["args"]["query"] == "pytest" + assert calls[0]["id"].startswith("call_") + + def test_multiple_tool_calls_parsed(self): + content = "12" + _, calls = _parse_xml_tool_call_to_dict(content) + assert len(calls) == 2 + assert calls[0]["name"] == "a" + assert calls[1]["name"] == "b" + + def test_text_before_tool_call_preserved(self): + content = "Here is the answer.\nv" + clean, calls = _parse_xml_tool_call_to_dict(content) + assert clean == "Here is the answer." + assert len(calls) == 1 + + def test_integer_param_deserialised(self): + content = "42" + _, calls = _parse_xml_tool_call_to_dict(content) + assert calls[0]["args"]["n"] == 42 + + def test_list_param_deserialised(self): + content = '["a","b"]' + _, calls = _parse_xml_tool_call_to_dict(content) + assert calls[0]["args"]["lst"] == ["a", "b"] + + def test_dict_param_deserialised(self): + content = '{"k": 1}' + _, calls = _parse_xml_tool_call_to_dict(content) + assert calls[0]["args"]["d"] == {"k": 1} + + def test_bool_param_deserialised(self): + content = "true" + _, calls = _parse_xml_tool_call_to_dict(content) + assert calls[0]["args"]["flag"] is True + + def test_malformed_param_stays_string(self): + content = "{broken json" + _, calls = _parse_xml_tool_call_to_dict(content) + assert calls[0]["args"]["bad"] == "{broken json" + + def test_non_string_input_returned_as_is(self): + result = _parse_xml_tool_call_to_dict(None) + assert result == (None, []) + + def test_unique_ids_generated(self): + block = "v" + _, c1 = _parse_xml_tool_call_to_dict(block) + _, c2 = _parse_xml_tool_call_to_dict(block) + assert c1[0]["id"] != c2[0]["id"] + + +# ═════════════════════════════════════════════════════════════════════════════ +# 3. MindIEChatModel._patch_result_with_tools +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestPatchResult: + def _model(self): + with patch.object(MindIEChatModel, "__init__", return_value=None): + m = MindIEChatModel.__new__(MindIEChatModel) + return m + + def test_escaped_newlines_fixed(self): + model = self._model() + result = _make_chat_result("line1\\nline2") + patched = model._patch_result_with_tools(result) + assert patched.generations[0].message.content == "line1\nline2" + + def test_xml_tool_calls_extracted(self): + model = self._model() + content = "1+1" + result = _make_chat_result(content) + patched = model._patch_result_with_tools(result) + msg = patched.generations[0].message + assert msg.content == "" + assert len(msg.tool_calls) == 1 + assert msg.tool_calls[0]["name"] == "calc" + + def test_patch_result_appends_to_existing_tool_calls(self): + model = self._model() + existing = [{"name": "existing", "args": {}, "id": "e1"}] + content = "v" + result = _make_chat_result(content, tool_calls=existing) + patched = model._patch_result_with_tools(result) + msg = patched.generations[0].message + assert len(msg.tool_calls) == 2 + names = [tc["name"] for tc in msg.tool_calls] + assert "existing" in names + assert "new_tool" in names + + def test_no_tool_call_content_unchanged(self): + model = self._model() + result = _make_chat_result("plain reply") + patched = model._patch_result_with_tools(result) + assert patched.generations[0].message.content == "plain reply" + + def test_non_string_content_skipped(self): + model = self._model() + msg = AIMessage(content=[{"type": "text", "text": "hi"}]) + gen = ChatGeneration(message=msg) + result = ChatResult(generations=[gen]) + patched = model._patch_result_with_tools(result) + assert patched is not None + + +# ═════════════════════════════════════════════════════════════════════════════ +# 4. MindIEChatModel._generate (sync) +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestGenerate: + def test_generate_calls_fix_messages_and_patch(self): + with patch("deerflow.models.mindie_provider.ChatOpenAI._generate") as mock_super_gen, patch.object(MindIEChatModel, "__init__", return_value=None): + mock_super_gen.return_value = _make_chat_result("hello") + model = MindIEChatModel.__new__(MindIEChatModel) + + msgs = [HumanMessage(content="ping")] + result = model._generate(msgs) + + assert mock_super_gen.called + called_msgs = mock_super_gen.call_args[0][0] + assert all(isinstance(m.content, str) for m in called_msgs) + assert result.generations[0].message.content == "hello" + + +# ═════════════════════════════════════════════════════════════════════════════ +# 5. MindIEChatModel._agenerate (async) +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestAGenerate: + @pytest.mark.asyncio + async def test_agenerate_patches_result(self): + with patch("deerflow.models.mindie_provider.ChatOpenAI._agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None): + mock_ag.return_value = _make_chat_result("world\\nfoo") + model = MindIEChatModel.__new__(MindIEChatModel) + + result = await model._agenerate([HumanMessage(content="hi")]) + assert result.generations[0].message.content == "world\nfoo" + + +# ═════════════════════════════════════════════════════════════════════════════ +# 6. MindIEChatModel._astream (async generator) +# ═════════════════════════════════════════════════════════════════════════════ + + +class TestAStream: + async def _collect(self, gen): + chunks = [] + async for chunk in gen: + chunks.append(chunk) + return chunks + + @pytest.mark.asyncio + async def test_no_tools_uses_real_stream(self): + from langchain_core.messages import AIMessageChunk + from langchain_core.outputs import ChatGenerationChunk + + async def fake_stream(*args, **kwargs): + for char in ["hel", "lo"]: + yield ChatGenerationChunk(message=AIMessageChunk(content=char)) + + with patch("deerflow.models.mindie_provider.ChatOpenAI._astream", side_effect=fake_stream), patch.object(MindIEChatModel, "__init__", return_value=None): + model = MindIEChatModel.__new__(MindIEChatModel) + chunks = await self._collect(model._astream([HumanMessage(content="hi")])) + + assert "".join(c.message.content for c in chunks) == "hello" + + @pytest.mark.asyncio + async def test_no_tools_fixes_escaped_newlines_in_stream(self): + from langchain_core.messages import AIMessageChunk + from langchain_core.outputs import ChatGenerationChunk + + async def fake_stream(*args, **kwargs): + yield ChatGenerationChunk(message=AIMessageChunk(content="a\\nb")) + + with patch("deerflow.models.mindie_provider.ChatOpenAI._astream", side_effect=fake_stream), patch.object(MindIEChatModel, "__init__", return_value=None): + model = MindIEChatModel.__new__(MindIEChatModel) + chunks = await self._collect(model._astream([HumanMessage(content="x")])) + + assert chunks[0].message.content == "a\nb" + + @pytest.mark.asyncio + async def test_with_tools_fake_streams_text_in_chunks(self): + with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None): + long_text = "A" * 50 + mock_ag.return_value = _make_chat_result(long_text) + model = MindIEChatModel.__new__(MindIEChatModel) + + chunks = await self._collect(model._astream([HumanMessage(content="q")], tools=[{"type": "function", "function": {"name": "dummy"}}])) + + full = "".join(c.message.content for c in chunks) + assert full == long_text + assert len(chunks) > 1 + + @pytest.mark.asyncio + async def test_with_tools_emits_tool_call_chunk(self): + + tool_calls = [{"name": "fn", "args": {}, "id": "c1"}] + with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None): + mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls) + model = MindIEChatModel.__new__(MindIEChatModel) + + chunks = await self._collect(model._astream([HumanMessage(content="q")], tools=[{"type": "function", "function": {"name": "fn"}}])) + + tool_chunks = [c for c in chunks if getattr(c.message, "tool_calls", [])] + assert tool_chunks, "No chunk carried tool_calls" + assert tool_chunks[-1].message.tool_calls[0]["name"] == "fn" + + @pytest.mark.asyncio + async def test_with_tools_empty_text_still_emits_tool_chunk(self): + tool_calls = [{"name": "x", "args": {}, "id": "c2"}] + with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None): + mock_ag.return_value = _make_chat_result("", tool_calls=tool_calls) + model = MindIEChatModel.__new__(MindIEChatModel) + + chunks = await self._collect(model._astream([HumanMessage(content="q")], tools=[{"type": "function", "function": {"name": "x"}}])) + + assert any(getattr(c.message, "tool_calls", []) for c in chunks) diff --git a/backend/uv.lock b/backend/uv.lock index 716b7e07a..bd2630869 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -688,6 +688,7 @@ dependencies = [ dev = [ { name = "prompt-toolkit" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ruff" }, ] @@ -711,6 +712,7 @@ requires-dist = [ dev = [ { name = "prompt-toolkit", specifier = ">=3.0.0" }, { name = "pytest", specifier = ">=9.0.3" }, + { name = "pytest-asyncio", specifier = ">=1.3.0" }, { name = "ruff", specifier = ">=0.14.11" }, ] @@ -3127,6 +3129,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5", size = 50087, upload-time = "2025-11-10T16:07:47.256Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/35/f8b19922b6a25bc0880171a2f1a003eaeb93657475193ab516fd87cac9da/pytest_asyncio-1.3.0-py3-none-any.whl", hash = "sha256:611e26147c7f77640e6d0a92a38ed17c3e9848063698d5c93d5aa7aa11cebff5", size = 15075, upload-time = "2025-11-10T16:07:45.537Z" }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" diff --git a/config.example.yaml b/config.example.yaml index 1c5bf4129..32a94105a 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -326,6 +326,27 @@ models: # chat_template_kwargs: # enable_thinking: true + + # Example: Qwen3-Coder deployed on MindIE Engine + # - name: Qwen3_Coder_480B_MindIE + # display_name: Qwen3-Coder-480B (MindIE) + # use: deerflow.models.mindie_provider:MindIEChatModel + # model: Qwen3-Coder-480B-A35B-Instruct-Client + # base_url: http://localhost:8989/v1 + # api_key: $OPENAI_API_KEY + # temperature: 0 + # max_retries: 1 + # supports_thinking: false + # supports_vision: false + # supports_reasoning_effort: false + # # --- Advanced Network Settings --- + # # Due to MindIE's streaming limitations with tool calling, the provider + # # uses mock-streaming (awaiting full generation). Extended timeouts are required. + # read_timeout: 900.0 # 15 minutes to prevent drops during long document generation + # connect_timeout: 30.0 + # write_timeout: 60.0 + # pool_timeout: 30.0 + # ============================================================================ # Tool Groups Configuration # ============================================================================