diff --git a/backend/packages/harness/deerflow/sandbox/tools.py b/backend/packages/harness/deerflow/sandbox/tools.py index 32ee7d646..a20004a8a 100644 --- a/backend/packages/harness/deerflow/sandbox/tools.py +++ b/backend/packages/harness/deerflow/sandbox/tools.py @@ -3,10 +3,9 @@ import re import shlex from pathlib import Path -from langchain.tools import ToolRuntime, tool -from langgraph.typing import ContextT +from langchain.tools import tool -from deerflow.agents.thread_state import ThreadDataState, ThreadState +from deerflow.agents.thread_state import ThreadDataState from deerflow.config import get_app_config from deerflow.config.paths import VIRTUAL_PATH_PREFIX from deerflow.sandbox.exceptions import ( @@ -19,6 +18,7 @@ from deerflow.sandbox.sandbox import Sandbox from deerflow.sandbox.sandbox_provider import get_sandbox_provider from deerflow.sandbox.search import GrepMatch from deerflow.sandbox.security import LOCAL_HOST_BASH_DISABLED_MESSAGE, is_host_bash_allowed +from deerflow.tools.types import Runtime _ABSOLUTE_PATH_PATTERN = re.compile(r"(?()]+)") _FILE_URL_PATTERN = re.compile(r"\bfile://\S+", re.IGNORECASE) @@ -419,7 +419,7 @@ def _join_path_preserving_style(base: str, relative: str) -> str: return f"{stripped_base}{separator}{normalized_relative}" -def _sanitize_error(error: Exception, runtime: "ToolRuntime[ContextT, ThreadState] | None" = None) -> str: +def _sanitize_error(error: Exception, runtime: Runtime | None = None) -> str: """Sanitize an error message to avoid leaking host filesystem paths. 
In local-sandbox mode, resolved host paths in the error string are masked @@ -994,7 +994,7 @@ def _apply_cwd_prefix(command: str, thread_data: ThreadDataState | None) -> str: return command -def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> ThreadDataState | None: +def get_thread_data(runtime: Runtime | None) -> ThreadDataState | None: """Extract thread_data from runtime state.""" if runtime is None: return None @@ -1003,7 +1003,7 @@ def get_thread_data(runtime: ToolRuntime[ContextT, ThreadState] | None) -> Threa return runtime.state.get("thread_data") -def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool: +def is_local_sandbox(runtime: Runtime | None) -> bool: """Check if the current sandbox is a local sandbox. Path replacement is only needed for local sandbox since aio sandbox @@ -1019,7 +1019,7 @@ def is_local_sandbox(runtime: ToolRuntime[ContextT, ThreadState] | None) -> bool return sandbox_state.get("sandbox_id") == "local" -def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox: +def sandbox_from_runtime(runtime: Runtime | None = None) -> Sandbox: """Extract sandbox instance from tool runtime. DEPRECATED: Use ensure_sandbox_initialized() for lazy initialization support. @@ -1048,7 +1048,7 @@ def sandbox_from_runtime(runtime: ToolRuntime[ContextT, ThreadState] | None = No return sandbox -def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | None = None) -> Sandbox: +def ensure_sandbox_initialized(runtime: Runtime | None = None) -> Sandbox: """Ensure sandbox is initialized, acquiring lazily if needed. On first call, acquires a sandbox from the provider and stores it in runtime state. 
@@ -1107,7 +1107,7 @@ def ensure_sandbox_initialized(runtime: ToolRuntime[ContextT, ThreadState] | Non return sandbox -def ensure_thread_directories_exist(runtime: ToolRuntime[ContextT, ThreadState] | None) -> None: +def ensure_thread_directories_exist(runtime: Runtime | None) -> None: """Ensure thread data directories (workspace, uploads, outputs) exist. This function is called lazily when any sandbox tool is first used. @@ -1221,7 +1221,7 @@ def _truncate_ls_output(output: str, max_chars: int) -> str: @tool("bash", parse_docstring=True) -def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, command: str) -> str: +def bash_tool(runtime: Runtime, description: str, command: str) -> str: """Execute a bash command in a Linux environment. @@ -1270,7 +1270,7 @@ def bash_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, com @tool("ls", parse_docstring=True) -def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: str) -> str: +def ls_tool(runtime: Runtime, description: str, path: str) -> str: """List the contents of a directory up to 2 levels deep in tree format. 
Args: @@ -1318,7 +1318,7 @@ def ls_tool(runtime: ToolRuntime[ContextT, ThreadState], description: str, path: @tool("glob", parse_docstring=True) def glob_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, description: str, pattern: str, path: str, @@ -1368,7 +1368,7 @@ def glob_tool( @tool("grep", parse_docstring=True) def grep_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, description: str, pattern: str, path: str, @@ -1438,7 +1438,7 @@ def grep_tool( @tool("read_file", parse_docstring=True) def read_file_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, description: str, path: str, start_line: int | None = None, @@ -1493,7 +1493,7 @@ def read_file_tool( @tool("write_file", parse_docstring=True) def write_file_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, description: str, path: str, content: str, @@ -1533,7 +1533,7 @@ def write_file_tool( @tool("str_replace", parse_docstring=True) def str_replace_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, description: str, path: str, old_str: str, diff --git a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py index 13a7a017e..c091e01df 100644 --- a/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/present_file_tool.py @@ -1,20 +1,19 @@ from pathlib import Path from typing import Annotated -from langchain.tools import InjectedToolCallId, ToolRuntime, tool +from langchain.tools import InjectedToolCallId, tool from langchain_core.messages import ToolMessage from langgraph.config import get_config from langgraph.types import Command -from langgraph.typing import ContextT -from deerflow.agents.thread_state import ThreadState from deerflow.config.paths import VIRTUAL_PATH_PREFIX, get_paths from deerflow.runtime.user_context import 
get_effective_user_id +from deerflow.tools.types import Runtime OUTPUTS_VIRTUAL_PREFIX = f"{VIRTUAL_PATH_PREFIX}/outputs" -def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None: +def _get_thread_id(runtime: Runtime) -> str | None: """Resolve the current thread id from runtime context or RunnableConfig.""" thread_id = runtime.context.get("thread_id") if runtime.context else None if thread_id: @@ -32,7 +31,7 @@ def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState]) -> str | None: def _normalize_presented_filepath( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, filepath: str, ) -> str: """Normalize a presented file path to the `/mnt/user-data/outputs/*` contract. @@ -83,7 +82,7 @@ def _normalize_presented_filepath( @tool("present_files", parse_docstring=True) def present_file_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, filepaths: list[str], tool_call_id: Annotated[str, InjectedToolCallId], ) -> Command: diff --git a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py index d05d0cc73..5ea591f76 100644 --- a/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/setup_agent_tool.py @@ -3,12 +3,12 @@ import logging import yaml from langchain_core.messages import ToolMessage from langchain_core.tools import tool -from langgraph.prebuilt import ToolRuntime from langgraph.types import Command from deerflow.config.agents_config import validate_agent_name from deerflow.config.paths import get_paths from deerflow.runtime.user_context import get_effective_user_id +from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) def setup_agent( soul: str, description: str, - runtime: ToolRuntime, + runtime: Runtime, skills: list[str] | None = None, ) -> Command: """Setup the 
custom DeerFlow agent. diff --git a/backend/packages/harness/deerflow/tools/builtins/task_tool.py b/backend/packages/harness/deerflow/tools/builtins/task_tool.py index 1328507b2..0154f6a7a 100644 --- a/backend/packages/harness/deerflow/tools/builtins/task_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/task_tool.py @@ -6,11 +6,9 @@ import uuid from dataclasses import replace from typing import TYPE_CHECKING, Annotated, Any, cast -from langchain.tools import InjectedToolCallId, ToolRuntime, tool +from langchain.tools import InjectedToolCallId, tool from langgraph.config import get_stream_writer -from langgraph.typing import ContextT -from deerflow.agents.thread_state import ThreadState from deerflow.config import get_app_config from deerflow.sandbox.security import LOCAL_BASH_SUBAGENT_DISABLED_MESSAGE, is_host_bash_allowed from deerflow.subagents import SubagentExecutor, get_available_subagent_names, get_subagent_config @@ -21,6 +19,7 @@ from deerflow.subagents.executor import ( get_background_task_result, request_cancel_background_task, ) +from deerflow.tools.types import Runtime if TYPE_CHECKING: from deerflow.config.app_config import AppConfig @@ -50,7 +49,7 @@ def _merge_skill_allowlists(parent: list[str] | None, child: list[str] | None) - @tool("task", parse_docstring=True) async def task_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, description: str, prompt: str, subagent_type: str, diff --git a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py index 76cedefcf..90d951859 100644 --- a/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/update_agent_tool.py @@ -22,13 +22,13 @@ from typing import Any import yaml from langchain_core.messages import ToolMessage from langchain_core.tools import tool -from langgraph.prebuilt import ToolRuntime from langgraph.types 
import Command from deerflow.config.agents_config import load_agent_config, validate_agent_name from deerflow.config.app_config import get_app_config from deerflow.config.paths import get_paths from deerflow.runtime.user_context import get_effective_user_id +from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) @@ -69,7 +69,7 @@ def _cleanup_temps(temps: list[Path]) -> None: @tool def update_agent( - runtime: ToolRuntime, + runtime: Runtime, soul: str | None = None, description: str | None = None, skills: list[str] | None = None, diff --git a/backend/packages/harness/deerflow/tools/builtins/view_image_tool.py b/backend/packages/harness/deerflow/tools/builtins/view_image_tool.py index 3dedcab70..3895cfd12 100644 --- a/backend/packages/harness/deerflow/tools/builtins/view_image_tool.py +++ b/backend/packages/harness/deerflow/tools/builtins/view_image_tool.py @@ -3,13 +3,13 @@ import mimetypes from pathlib import Path from typing import Annotated -from langchain.tools import InjectedToolCallId, ToolRuntime, tool +from langchain.tools import InjectedToolCallId, tool from langchain_core.messages import ToolMessage from langgraph.types import Command -from langgraph.typing import ContextT -from deerflow.agents.thread_state import ThreadDataState, ThreadState +from deerflow.agents.thread_state import ThreadDataState from deerflow.config.paths import VIRTUAL_PATH_PREFIX +from deerflow.tools.types import Runtime _ALLOWED_IMAGE_VIRTUAL_ROOTS = ( f"{VIRTUAL_PATH_PREFIX}/workspace", @@ -48,7 +48,7 @@ def _sanitize_image_error(error: Exception, thread_data: ThreadDataState | None) @tool("view_image", parse_docstring=True) def view_image_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, image_path: str, tool_call_id: Annotated[str, InjectedToolCallId], ) -> Command: diff --git a/backend/packages/harness/deerflow/tools/skill_manage_tool.py b/backend/packages/harness/deerflow/tools/skill_manage_tool.py index c0114eb08..46865242c 
100644 --- a/backend/packages/harness/deerflow/tools/skill_manage_tool.py +++ b/backend/packages/harness/deerflow/tools/skill_manage_tool.py @@ -7,16 +7,15 @@ import logging from typing import Any from weakref import WeakValueDictionary -from langchain.tools import ToolRuntime, tool -from langgraph.typing import ContextT +from langchain.tools import tool from deerflow.agents.lead_agent.prompt import refresh_skills_system_prompt_cache_async -from deerflow.agents.thread_state import ThreadState from deerflow.mcp.tools import _make_sync_tool_wrapper from deerflow.skills.security_scanner import scan_skill_content from deerflow.skills.storage import get_or_new_skill_storage from deerflow.skills.storage.skill_storage import SkillStorage from deerflow.skills.types import SKILL_MD_FILE +from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) @@ -31,7 +30,7 @@ def _get_lock(name: str) -> asyncio.Lock: return lock -def _get_thread_id(runtime: ToolRuntime[ContextT, ThreadState] | None) -> str | None: +def _get_thread_id(runtime: Runtime | None) -> str | None: if runtime is None: return None if runtime.context and runtime.context.get("thread_id"): @@ -65,7 +64,7 @@ async def _to_thread(func, /, *args, **kwargs): async def _skill_manage_impl( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, action: str, name: str, content: str | None = None, @@ -204,7 +203,7 @@ async def _skill_manage_impl( @tool("skill_manage", parse_docstring=True) async def skill_manage_tool( - runtime: ToolRuntime[ContextT, ThreadState], + runtime: Runtime, action: str, name: str, content: str | None = None, diff --git a/backend/packages/harness/deerflow/tools/types.py b/backend/packages/harness/deerflow/tools/types.py new file mode 100644 index 000000000..4fffbd6c7 --- /dev/null +++ b/backend/packages/harness/deerflow/tools/types.py @@ -0,0 +1,11 @@ +from typing import Any + +from langchain.tools import ToolRuntime + +from deerflow.agents.thread_state import 
ThreadState + +# Concrete runtime type used by all DeerFlow tools. +# Using dict[str, Any] for the context parameter instead of the unbound ContextT +# TypeVar prevents PydanticSerializationUnexpectedValue warnings when LangChain +# calls model_dump() on a tool's auto-generated args_schema. +Runtime = ToolRuntime[dict[str, Any], ThreadState] diff --git a/backend/tests/test_tool_args_schema_no_pydantic_warning.py b/backend/tests/test_tool_args_schema_no_pydantic_warning.py new file mode 100644 index 000000000..037771b3e --- /dev/null +++ b/backend/tests/test_tool_args_schema_no_pydantic_warning.py @@ -0,0 +1,91 @@ +"""Regression test: tool args schemas must not emit Pydantic serialization warnings. + +DeerFlow tools annotate their runtime parameter as ``Runtime`` +(``deerflow.tools.types.Runtime`` = ``ToolRuntime[dict[str, Any], ThreadState]``) +so the LangChain tool framework injects the runtime automatically. +When the ``ToolRuntime.context`` field is left as the unbound ``ContextT`` +TypeVar (default ``None``), Pydantic's ``model_dump()`` on the auto-generated +args schema emits a ``PydanticSerializationUnexpectedValue`` warning on every +tool call because the actual context DeerFlow installs is a dict. Using the +``Runtime`` alias (which binds the context to ``dict[str, Any]``) keeps +Pydantic's serialization expectations aligned with reality. 
+""" + +from __future__ import annotations + +import warnings + +import pytest +from langchain.tools import ToolRuntime + +from deerflow.sandbox.tools import ( + bash_tool, + glob_tool, + grep_tool, + ls_tool, + read_file_tool, + str_replace_tool, + write_file_tool, +) +from deerflow.tools.builtins.present_file_tool import present_file_tool +from deerflow.tools.builtins.setup_agent_tool import setup_agent +from deerflow.tools.builtins.task_tool import task_tool +from deerflow.tools.builtins.update_agent_tool import update_agent +from deerflow.tools.builtins.view_image_tool import view_image_tool +from deerflow.tools.skill_manage_tool import skill_manage_tool + + +def _make_runtime(context: dict) -> ToolRuntime: + return ToolRuntime( + state={"sandbox": {"sandbox_id": "local"}, "thread_data": {}}, + context=context, + config={"configurable": {"thread_id": context.get("thread_id", "thread-1")}}, + stream_writer=lambda _: None, + tools=[], + tool_call_id="call-1", + store=None, + ) + + +_TOOL_CASES = [ + (bash_tool, {"description": "list", "command": "ls"}), + (ls_tool, {"description": "list", "path": "/tmp"}), + (glob_tool, {"description": "find", "pattern": "*.py", "path": "/tmp"}), + (grep_tool, {"description": "search", "pattern": "x", "path": "/tmp"}), + (read_file_tool, {"description": "read", "path": "/tmp/x"}), + (write_file_tool, {"description": "write", "path": "/tmp/x", "content": "hi"}), + (str_replace_tool, {"description": "replace", "path": "/tmp/x", "old_str": "a", "new_str": "b"}), + (present_file_tool, {"filepaths": ["/tmp/x"], "tool_call_id": "call-1"}), + (view_image_tool, {"image_path": "/tmp/img.png", "tool_call_id": "call-1"}), + (task_tool, {"description": "do", "prompt": "go", "subagent_type": "general-purpose", "tool_call_id": "call-1"}), + (skill_manage_tool, {"action": "list", "name": "demo"}), + (setup_agent, {"soul": "s", "description": "d"}), + (update_agent, {}), +] + + +@pytest.mark.parametrize( + ("tool_obj", "extra_args"), + 
_TOOL_CASES, + ids=[case[0].name for case in _TOOL_CASES], +) +def test_tool_args_schema_does_not_emit_pydantic_context_warning(tool_obj, extra_args) -> None: + """``model_dump()`` of the auto-generated args_schema must not warn about ``context``. + + The model_dump path is hit by LangChain's ``BaseTool._parse_input`` on every tool + invocation (see langchain_core/tools/base.py:712), so any warning here would fire + once per tool call and pollute production logs. + """ + schema = tool_obj.args_schema + assert schema is not None, f"{tool_obj.name} has no args_schema" + + runtime_obj = _make_runtime({"thread_id": "thread-1", "sandbox_id": "local"}) + payload = {**extra_args, "runtime": runtime_obj} + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + validated = schema.model_validate(payload) + validated.model_dump() + + pydantic_warnings = [w for w in caught if "PydanticSerializationUnexpectedValue" in str(w.message)] + assert not pydantic_warnings, f"{tool_obj.name} args_schema.model_dump() emitted Pydantic context serialization warnings: {[str(w.message) for w in pydantic_warnings]}"