"""Tool management for function calling and MCP.""" import asyncio import base64 import binascii from dataclasses import dataclass import inspect import logging import mimetypes import os import threading from pathlib import Path from typing import Any, Dict, List, Mapping, Sequence from fastmcp import Client from fastmcp.client.client import CallToolResult as FastMcpCallToolResult from fastmcp.client.transports import StreamableHttpTransport, StdioTransport from mcp import types from entity.configs import ToolingConfig, ConfigError from entity.configs.node.tooling import FunctionToolConfig, McpLocalConfig, McpRemoteConfig from entity.messages import MessageBlock, MessageBlockType from entity.tool_spec import ToolSpec from utils.attachments import AttachmentStore from utils.function_manager import FUNCTION_CALLING_DIR, FunctionManager logger = logging.getLogger(__name__) DEFAULT_MCP_HTTP_TIMEOUT = 10.0 @dataclass class _FunctionManagerCacheEntry: manager: FunctionManager auto_loaded: bool = False class ToolManager: """Manage function tools for agent nodes.""" def __init__(self) -> None: self._functions_dir: Path = FUNCTION_CALLING_DIR self._function_managers: Dict[Path, _FunctionManagerCacheEntry] = {} self._mcp_tool_cache: Dict[str, List[Any]] = {} self._mcp_stdio_clients: Dict[str, "_StdioClientWrapper"] = {} def _get_function_manager(self) -> FunctionManager: entry = self._function_managers.get(self._functions_dir) if entry is None: entry = _FunctionManagerCacheEntry(manager=FunctionManager(self._functions_dir)) self._function_managers[self._functions_dir] = entry return entry.manager def _ensure_functions_loaded(self, auto_load: bool) -> None: if not auto_load: return entry = self._function_managers.setdefault( self._functions_dir, _FunctionManagerCacheEntry(manager=FunctionManager(self._functions_dir)) ) if not entry.auto_loaded: entry.manager.load_functions() entry.auto_loaded = True async def _fetch_mcp_tools_http( self, server_url: str, *, headers: Dict[str, str] | None = None, timeout: float | None = None, attempts: int = 3, ) -> List[Any]: delay = 0.5 last_error: Exception | None = None for attempt in range(1, attempts + 1): try: client = Client( transport=StreamableHttpTransport(server_url, headers=headers or None), timeout=timeout or DEFAULT_MCP_HTTP_TIMEOUT, ) async with client: return await client.list_tools() except Exception as exc: # pragma: no cover - passthrough to caller last_error = exc if attempt == attempts: raise await asyncio.sleep(delay) delay *= 2 if last_error: raise last_error return [] async def _fetch_mcp_tools_stdio(self, config: McpLocalConfig, launch_key: str) -> List[Any]: client = self._get_stdio_client(config, launch_key) return client.list_tools() def get_tool_specs(self, tool_configs: List[ToolingConfig] | None) -> List[ToolSpec]: """Return provider-agnostic tool specifications for the given config list.""" if not tool_configs: return [] specs: List[ToolSpec] = [] seen_tools: set[str] = set() for idx, tool_config in enumerate(tool_configs): current_specs: List[ToolSpec] = [] if tool_config.type == "function": config = tool_config.as_config(FunctionToolConfig) if not config: raise ValueError("Function tooling configuration missing") current_specs = self._build_function_specs(config) elif tool_config.type == "mcp_remote": config = tool_config.as_config(McpRemoteConfig) if not config: raise ValueError("MCP remote configuration missing") current_specs = self._build_mcp_remote_specs(config) elif tool_config.type == "mcp_local": config = tool_config.as_config(McpLocalConfig) if not config: raise ValueError("MCP local configuration missing") current_specs = self._build_mcp_local_specs(config) else: # Skip unknown types or raise error? Existing code raised error in execute but ignored in get_specs? # Better to ignore or log warning for robustness, but let's stick to safe behavior. pass prefix = tool_config.prefix for spec in current_specs: original_name = spec.name final_name = f"{prefix}_{original_name}" if prefix else original_name if final_name in seen_tools: raise ConfigError( f"Duplicate tool name '{final_name}' detected. " f"Please use a unique 'prefix' in your tooling configuration." ) seen_tools.add(final_name) # Update spec spec.name = final_name spec.metadata["_config_index"] = idx spec.metadata["original_name"] = original_name specs.append(spec) return specs async def execute_tool( self, tool_name: str, arguments: Dict[str, Any], tool_config: ToolingConfig, *, tool_context: Dict[str, Any] | None = None, ) -> Any: """Execute a tool using the provided configuration.""" if tool_config.type == "function": config = tool_config.as_config(FunctionToolConfig) if not config: raise ValueError("Function tooling configuration missing") return self._execute_function_tool(tool_name, arguments, config, tool_context) if tool_config.type == "mcp_remote": config = tool_config.as_config(McpRemoteConfig) if not config: raise ValueError("MCP remote configuration missing") return await self._execute_mcp_remote_tool(tool_name, arguments, config, tool_context) if tool_config.type == "mcp_local": config = tool_config.as_config(McpLocalConfig) if not config: raise ValueError("MCP local configuration missing") return await self._execute_mcp_local_tool(tool_name, arguments, config, tool_context) raise ValueError(f"Unsupported tool type: {tool_config.type}") def _build_function_specs(self, config: FunctionToolConfig) -> List[ToolSpec]: self._ensure_functions_loaded(config.auto_load) specs: List[ToolSpec] = [] for tool in config.tools: parameters = tool.get("parameters") if not isinstance(parameters, Mapping): parameters = {"type": "object", "properties": {}} specs.append( ToolSpec( name=tool.get("name", ""), description=tool.get("description") or "", parameters=parameters, metadata={"source": "function"}, ) ) return specs def _build_mcp_remote_specs(self, config: McpRemoteConfig) -> List[ToolSpec]: cache_key = f"remote:{config.cache_key()}" tools = self._mcp_tool_cache.get(cache_key) if tools is None: tools = asyncio.run( self._fetch_mcp_tools_http( config.server, headers=config.headers, timeout=config.timeout, ) ) self._mcp_tool_cache[cache_key] = tools specs: List[ToolSpec] = [] for tool in tools: specs.append( ToolSpec( name=tool.name, description=tool.description or "", parameters=tool.inputSchema or {"type": "object", "properties": {}}, metadata={"source": "mcp", "server": config.server, "mode": "remote"}, ) ) return specs def _build_mcp_local_specs(self, config: McpLocalConfig) -> List[ToolSpec]: launch_key = config.cache_key() if not launch_key: raise ValueError("MCP local configuration missing launch key") cache_key = f"stdio:{launch_key}" tools = self._mcp_tool_cache.get(cache_key) if tools is None: tools = asyncio.run(self._fetch_mcp_tools_stdio(config, launch_key)) self._mcp_tool_cache[cache_key] = tools specs: List[ToolSpec] = [] for tool in tools: specs.append( ToolSpec( name=tool.name, description=tool.description or "", parameters=tool.inputSchema or {"type": "object", "properties": {}}, metadata={"source": "mcp", "server": "stdio", "mode": "local"}, ) ) return specs def _execute_function_tool( self, tool_name: str, arguments: Dict[str, Any], config: FunctionToolConfig, tool_context: Dict[str, Any] | None = None, ) -> Any: mgr = self._get_function_manager() if config.auto_load: mgr.load_functions() func = mgr.get_function(tool_name) if func is None: raise ValueError(f"Tool {tool_name} not found in {self._functions_dir}") call_args = dict(arguments or {}) if ( tool_context is not None # and "_context" not in call_args and self._function_accepts_context(func) ): call_args["_context"] = tool_context return func(**call_args) def _function_accepts_context(self, func: Any) -> bool: try: signature = inspect.signature(func) except (ValueError, TypeError): return False for param in signature.parameters.values(): if param.kind is inspect.Parameter.VAR_KEYWORD: return True if param.name == "_context" and param.kind in ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, ): return True return False async def _execute_mcp_remote_tool( self, tool_name: str, arguments: Dict[str, Any], config: McpRemoteConfig, tool_context: Dict[str, Any] | None = None, ) -> Any: client = Client( transport=StreamableHttpTransport(config.server, headers=config.headers or None), timeout=config.timeout or DEFAULT_MCP_HTTP_TIMEOUT, ) async with client: result = await client.call_tool(tool_name, arguments) return self._normalize_mcp_result(tool_name, result, tool_context) async def _execute_mcp_local_tool( self, tool_name: str, arguments: Dict[str, Any], config: McpLocalConfig, tool_context: Dict[str, Any] | None = None, ) -> Any: launch_key = config.cache_key() if not launch_key: raise ValueError("MCP local configuration missing launch key") stdio_client = self._get_stdio_client(config, launch_key) result = stdio_client.call_tool(tool_name, arguments) return self._normalize_mcp_result(tool_name, result, tool_context) def _normalize_mcp_result( self, tool_name: str, result: FastMcpCallToolResult, tool_context: Dict[str, Any] | None, ) -> Any: attachment_store = self._extract_attachment_store(tool_context) blocks = self._convert_mcp_content_to_blocks(tool_name, result.content, attachment_store) if blocks: return blocks if result.structured_content is not None: return result.structured_content if result.content: content = result.content[0] if isinstance(content, types.TextContent): return content.text return str(content) return None def _extract_attachment_store(self, tool_context: Dict[str, Any] | None) -> AttachmentStore | None: if not tool_context: return None candidate = tool_context.get("attachment_store") if isinstance(candidate, AttachmentStore): return candidate if candidate is not None: logger.warning( "attachment_store in tool_context is not AttachmentStore (got %s)", type(candidate).__name__, ) return None def _convert_mcp_content_to_blocks( self, tool_name: str, contents: Sequence[types.ContentBlock] | None, attachment_store: AttachmentStore | None, ) -> List[MessageBlock]: blocks: List[MessageBlock] = [] if not contents: return blocks for idx, content in enumerate(contents): converted = self._convert_single_mcp_block(tool_name, content, idx, attachment_store) if converted: blocks.extend(converted) return blocks def _convert_single_mcp_block( self, tool_name: str, content: types.ContentBlock, block_index: int, attachment_store: AttachmentStore | None, ) -> List[MessageBlock]: if isinstance(content, types.TextContent): return [MessageBlock.text_block(content.text)] if isinstance(content, types.ImageContent): return self._materialize_mcp_binary_block( tool_name, content.data, content.mimeType, MessageBlockType.IMAGE, block_index, attachment_store, ) if isinstance(content, types.AudioContent): return self._materialize_mcp_binary_block( tool_name, content.data, content.mimeType, MessageBlockType.AUDIO, block_index, attachment_store, ) if isinstance(content, types.EmbeddedResource): resource = content.resource if isinstance(resource, types.TextResourceContents): data_payload = { "uri": str(resource.uri), "mime_type": resource.mimeType, } return [ MessageBlock( type=MessageBlockType.TEXT, text=resource.text, data={k: v for k, v in data_payload.items() if v is not None}, ) ] if isinstance(resource, types.BlobResourceContents): extra = { "resource_uri": str(resource.uri), } return self._materialize_mcp_binary_block( tool_name, resource.blob, resource.mimeType, self._message_block_type_from_mime(resource.mimeType), block_index, attachment_store, extra=extra, ) if isinstance(content, types.ResourceLink): data_payload = { "uri": str(content.uri), "mime_type": content.mimeType, "description": content.description, } return [ MessageBlock( type=MessageBlockType.DATA, text=content.description or f"Resource link: {content.uri}", data={k: v for k, v in data_payload.items() if v is not None}, ) ] logger.warning("Unhandled MCP content block type: %s", type(content).__name__) return [] def _materialize_mcp_binary_block( self, tool_name: str, payload_b64: str, mime_type: str | None, block_type: MessageBlockType, block_index: int, attachment_store: AttachmentStore | None, *, extra: Dict[str, Any] | None = None, ) -> List[MessageBlock]: display_name = self._build_attachment_name(tool_name, block_type, block_index, mime_type) try: binary = base64.b64decode(payload_b64) except (binascii.Error, ValueError) as exc: logger.warning("Failed to decode MCP %s payload for %s: %s", block_type.value, tool_name, exc) return [ MessageBlock.text_block( f"[failed to decode {block_type.value} content from {tool_name}]" ) ] metadata = { "source": "mcp_tool", "tool_name": tool_name, "block_type": block_type.value, } if extra: metadata.update(extra) if attachment_store is None: placeholder = ( f"[binary content omitted: {display_name} ({mime_type or 'application/octet-stream'})]" ) return [ MessageBlock( type=MessageBlockType.TEXT, text=placeholder, data={**metadata, "reason": "attachment_store_missing", "mime_type": mime_type}, ) ] record = attachment_store.register_bytes( binary, kind=block_type, mime_type=mime_type, display_name=display_name, extra=metadata, ) return [record.as_message_block()] def _build_attachment_name( self, tool_name: str, block_type: MessageBlockType, block_index: int, mime_type: str | None, ) -> str: base = f"{tool_name}_{block_type.value}_{block_index + 1}".strip() or "attachment" safe_base = "".join(ch if ch.isalnum() or ch in {"-", "_"} else "_" for ch in base) ext = mimetypes.guess_extension(mime_type or "") or "" return f"{safe_base}{ext}" def _message_block_type_from_mime(self, mime_type: str | None) -> MessageBlockType: if not mime_type: return MessageBlockType.FILE if mime_type.startswith("image/"): return MessageBlockType.IMAGE if mime_type.startswith("audio/"): return MessageBlockType.AUDIO if mime_type.startswith("video/"): return MessageBlockType.VIDEO return MessageBlockType.FILE def _get_stdio_client(self, config: McpLocalConfig, launch_key: str) -> "_StdioClientWrapper": client = self._mcp_stdio_clients.get(launch_key) if client is None: client = _StdioClientWrapper(config) self._mcp_stdio_clients[launch_key] = client return client class _StdioClientWrapper: def __init__(self, config: McpLocalConfig) -> None: env = os.environ.copy() if config.inherit_env else {} env.update(config.env) env_payload = env or None transport = StdioTransport( command=config.command, args=list(config.args), env=env_payload, cwd=config.cwd, keep_alive=True, ) self._client = Client(transport=transport) self._loop = asyncio.new_event_loop() self._thread = threading.Thread(target=self._run_loop, daemon=True) self._thread.start() init_future = asyncio.run_coroutine_threadsafe(self._initialize(), self._loop) init_future.result() def _run_loop(self) -> None: asyncio.set_event_loop(self._loop) self._loop.run_forever() async def _initialize(self) -> None: self._lock = asyncio.Lock() await self._client.__aenter__() def list_tools(self) -> List[Any]: future = asyncio.run_coroutine_threadsafe(self._call("list_tools"), self._loop) return future.result() def call_tool(self, name: str, arguments: Dict[str, Any]) -> Any: future = asyncio.run_coroutine_threadsafe( self._call("call_tool", name, arguments), self._loop, ) return future.result() async def _call(self, method: str, *args: Any, **kwargs: Any) -> Any: async with self._lock: func = getattr(self._client, method) return await func(*args, **kwargs) def close(self) -> None: future = asyncio.run_coroutine_threadsafe(self._shutdown(), self._loop) future.result() self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join() async def _shutdown(self) -> None: async with self._lock: await self._client.__aexit__(None, None, None)